Introduce OAuth2AuthorizationRequest.attributes

Fixes gh-5940
This commit is contained in:
Joe Grandja 2019-02-05 10:17:46 -05:00
parent 67fb936c7e
commit 594a169798
13 changed files with 108 additions and 82 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -146,14 +146,14 @@ public class OAuth2ClientConfigurerTests {
this.spring.register(OAuth2ClientConfig.class).autowire();
// Setup the Authorization Request in the session
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri())
.clientId(this.registration1.getClientId())
.redirectUri("http://localhost/client-1")
.state("state")
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -473,7 +473,7 @@ public class OAuth2LoginConfigurerTests {
.clientId(registration.getClientId())
.state("state123")
.redirectUri("http://localhost")
.additionalParameters(
.attributes(
Collections.singletonMap(
OAuth2ParameterNames.REGISTRATION_ID,
registration.getRegistrationId()))

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -115,8 +115,8 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = builder
.clientId(clientRegistration.getClientId())
@ -124,7 +124,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();
return authorizationRequest;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -161,7 +161,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter {
OAuth2AuthorizationRequest authorizationRequest =
this.authorizationRequestRepository.removeAuthorizationRequest(request, response);
String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -165,7 +165,7 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String registrationId = (String) authorizationRequest.getAdditionalParameters().get(OAuth2ParameterNames.REGISTRATION_ID);
String registrationId = authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
if (clientRegistration == null) {
OAuth2Error oauth2Error = new OAuth2Error(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE,

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -118,9 +118,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
String redirectUriStr = this
.expandRedirectUri(exchange.getRequest(), clientRegistration);
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID,
clientRegistration.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
OAuth2AuthorizationRequest.Builder builder;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
@ -139,7 +138,7 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -85,9 +85,9 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter
private Mono<OAuth2AuthorizationCodeAuthenticationToken> authenticationRequest(ServerWebExchange exchange, OAuth2AuthorizationRequest authorizationRequest) {
return Mono.just(authorizationRequest)
.map(OAuth2AuthorizationRequest::getAdditionalParameters)
.flatMap(additionalParams -> {
String id = (String) additionalParams.get(OAuth2ParameterNames.REGISTRATION_ID);
.map(OAuth2AuthorizationRequest::getAttributes)
.flatMap(attributes -> {
String id = (String) attributes.get(OAuth2ParameterNames.REGISTRATION_ID);
if (id == null) {
return oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -105,7 +105,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
.isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
assertThat(authorizationRequest.getScopes()).isEqualTo(clientRegistration.getScopes());
assertThat(authorizationRequest.getState()).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters())
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OAuth2ParameterNames.REGISTRATION_ID);
assertThat(authorizationRequest.getAttributes())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
assertThat(authorizationRequest.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" +
@ -123,7 +124,7 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters())
assertThat(authorizationRequest.getAttributes())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -402,15 +402,15 @@ public class OAuth2LoginAuthenticationFilterTests {
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
ClientRegistration registration, String state) {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
.clientId(registration.getClientId())
.redirectUri(expandRedirectUri(request, registration))
.scopes(registration.getScopes())
.state(state)
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -74,7 +74,7 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest {
.clientId("client-id")
.redirectUri("http://localhost/client-1")
.state("state")
.additionalParameters(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId));
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId));
private final MockServerHttpRequest.BaseBuilder<?> request = MockServerHttpRequest.get("/");
@ -95,8 +95,8 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTest {
}
@Test
public void applyWhenAdditionalParametersMissingThenOAuth2AuthorizationException() {
this.authorizationRequest.additionalParameters(Collections.emptyMap());
public void applyWhenAttributesMissingThenOAuth2AuthorizationException() {
this.authorizationRequest.attributes(Collections.emptyMap());
when(this.authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(this.authorizationRequest.build()));
assertThatThrownBy(() -> applyConverter())

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,6 +15,15 @@
*/
package org.springframework.security.oauth2.core.endpoint;
import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@ -25,15 +34,6 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
/**
* A representation of an OAuth 2.0 Authorization Request
* for the authorization code grant type or implicit grant type.
@ -56,6 +56,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private String state;
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private Map<String, Object> attributes;
private OAuth2AuthorizationRequest() {
}
@ -132,6 +133,29 @@ public final class OAuth2AuthorizationRequest implements Serializable {
return this.additionalParameters;
}
/**
* Returns the attributes associated to the request.
*
* @since 5.2
* @return a {@code Map} of the attributes associated to the request
*/
public Map<String, Object> getAttributes() {
return this.attributes;
}
/**
* Returns the value of an attribute associated to the request, or {@code null} if not available.
*
* @since 5.2
* @param name the name of the attribute
* @param <T> the type of the attribute
* @return the value of the attribute associated to the request
*/
@SuppressWarnings("unchecked")
public <T> T getAttribute(String name) {
return (T) this.getAttributes().get(name);
}
/**
* Returns the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
@ -181,7 +205,8 @@ public final class OAuth2AuthorizationRequest implements Serializable {
.redirectUri(authorizationRequest.getRedirectUri())
.scopes(authorizationRequest.getScopes())
.state(authorizationRequest.getState())
.additionalParameters(authorizationRequest.getAdditionalParameters());
.additionalParameters(authorizationRequest.getAdditionalParameters())
.attributes(authorizationRequest.getAttributes());
}
/**
@ -197,6 +222,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private String state;
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private Map<String, Object> attributes;
private Builder(AuthorizationGrantType authorizationGrantType) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
@ -288,6 +314,18 @@ public final class OAuth2AuthorizationRequest implements Serializable {
return this;
}
/**
* Sets the attributes associated to the request.
*
* @since 5.2
* @param attributes the attributes associated to the request
* @return the {@link Builder}
*/
public Builder attributes(Map<String, Object> attributes) {
this.attributes = attributes;
return this;
}
/**
* Sets the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
@ -332,6 +370,9 @@ public final class OAuth2AuthorizationRequest implements Serializable {
authorizationRequest.authorizationRequestUri =
StringUtils.hasText(this.authorizationRequestUri) ?
this.authorizationRequestUri : this.buildAuthorizationRequestUri();
authorizationRequest.attributes = Collections.unmodifiableMap(
CollectionUtils.isEmpty(this.attributes) ?
Collections.emptyMap() : new LinkedHashMap<>(this.attributes));
return authorizationRequest;
}
@ -351,9 +392,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
parameters.set(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
}
if (!CollectionUtils.isEmpty(this.additionalParameters)) {
this.additionalParameters.entrySet().stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.REGISTRATION_ID))
.forEach(e -> parameters.set(e.getKey(), e.getValue().toString()));
this.additionalParameters.forEach((k, v) -> parameters.set(k, v.toString()));
}
return UriComponentsBuilder.fromHttpUrl(this.authorizationUri)

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,19 +15,16 @@
*/
package org.springframework.security.oauth2.core.endpoint;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.junit.Test;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.*;
/**
* Tests for {@link OAuth2AuthorizationRequest}.
@ -166,6 +163,10 @@ public class OAuth2AuthorizationRequestTests {
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
Map<String, Object> attributes = new HashMap<>();
attributes.put("attribute1", "value1");
attributes.put("attribute2", "value2");
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
@ -173,6 +174,7 @@ public class OAuth2AuthorizationRequestTests {
.scopes(SCOPES)
.state(STATE)
.additionalParameters(additionalParameters)
.attributes(attributes)
.authorizationRequestUri(AUTHORIZATION_URI)
.build();
@ -184,6 +186,7 @@ public class OAuth2AuthorizationRequestTests {
assertThat(authorizationRequest.getScopes()).isEqualTo(SCOPES);
assertThat(authorizationRequest.getState()).isEqualTo(STATE);
assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(additionalParameters);
assertThat(authorizationRequest.getAttributes()).isEqualTo(attributes);
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
}
@ -250,28 +253,6 @@ public class OAuth2AuthorizationRequestTests {
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id");
}
@Test
public void buildWhenAuthorizationRequestIncludesRegistrationIdParameterThenAuthorizationRequestUriDoesNotIncludeRegistrationIdParameter() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, "registration1");
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI + "?rparam1=rvalue1&rparam2=rvalue2")
.scopes(SCOPES)
.state(STATE)
.additionalParameters(additionalParameters)
.build();
assertThat(authorizationRequest.getAuthorizationRequestUri())
.isEqualTo("https://provider.com/oauth2/authorize?" +
"response_type=code&client_id=client-id&" +
"scope=scope1%20scope2&state=state&" +
"redirect_uri=http://example.com?rparam1%3Drvalue1%26rparam2%3Drvalue2&param1=value1");
}
@Test
public void fromWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2AuthorizationRequest.from(null)).isInstanceOf(IllegalArgumentException.class);
@ -283,6 +264,10 @@ public class OAuth2AuthorizationRequestTests {
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
Map<String, Object> attributes = new HashMap<>();
attributes.put("attribute1", "value1");
attributes.put("attribute2", "value2");
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
@ -290,6 +275,7 @@ public class OAuth2AuthorizationRequestTests {
.scopes(SCOPES)
.state(STATE)
.additionalParameters(additionalParameters)
.attributes(attributes)
.build();
OAuth2AuthorizationRequest authorizationRequestCopy =
@ -303,6 +289,7 @@ public class OAuth2AuthorizationRequestTests {
assertThat(authorizationRequestCopy.getScopes()).isEqualTo(authorizationRequest.getScopes());
assertThat(authorizationRequestCopy.getState()).isEqualTo(authorizationRequest.getState());
assertThat(authorizationRequestCopy.getAdditionalParameters()).isEqualTo(authorizationRequest.getAdditionalParameters());
assertThat(authorizationRequestCopy.getAttributes()).isEqualTo(authorizationRequest.getAttributes());
assertThat(authorizationRequestCopy.getAuthorizationRequestUri()).isEqualTo(authorizationRequest.getAuthorizationRequestUri());
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,13 +27,13 @@ public class TestOAuth2AuthorizationRequests {
public static OAuth2AuthorizationRequest.Builder request() {
String registrationId = "registration-id";
String clientId = "client-id";
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, registrationId);
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registrationId);
return OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/login/oauth/authorize")
.clientId(clientId)
.redirectUri("https://example.com/authorize/oauth2/code/registration-id")
.state("state")
.additionalParameters(additionalParameters);
.attributes(attributes);
}
}