Add support for custom authorization request parameters

Fixes gh-4911
This commit is contained in:
Joe Grandja 2018-06-30 20:21:02 -04:00 committed by Rob Winch
parent a3210c96d9
commit 779597af2a
12 changed files with 867 additions and 423 deletions

View File

@ -120,7 +120,7 @@ public class OAuth2ClientConfigurerTests {
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
.andExpect(status().is3xxRedirection())
.andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/client-1");
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1");
}
@Test
@ -168,7 +168,7 @@ public class OAuth2ClientConfigurerTests {
MvcResult mvcResult = this.mockMvc.perform(get("/resource1").with(user("user1")))
.andExpect(status().is3xxRedirection())
.andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/client-1");
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1");
verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
}

View File

@ -0,0 +1,169 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.http.HttpServletRequest;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
/**
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
* resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}
* using the default request {@code URI} pattern {@code /oauth2/authorization/{registrationId}}.
*
* <p>
* <b>NOTE:</b> The default base {@code URI} {@code /oauth2/authorization} may be overridden
* via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
*
* @author Joe Grandja
* @since 5.1
* @see OAuth2AuthorizationRequestResolver
* @see OAuth2AuthorizationRequestRedirectFilter
*/
public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver {
private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
private final ClientRegistrationRepository clientRegistrationRepository;
private final AntPathRequestMatcher authorizationRequestMatcher;
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
/**
* Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
* @param authorizationRequestBaseUri the base {@code URI} used for resolving authorization requests
*/
public DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository clientRegistrationRepository,
String authorizationRequestBaseUri) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizationRequestMatcher = new AntPathRequestMatcher(
authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}");
}
@Override
public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
String registrationId = this.resolveRegistrationId(request);
if (registrationId == null) {
return null;
}
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
if (clientRegistration == null) {
throw new IllegalArgumentException("Invalid Client Registration with Id: " + registrationId);
}
OAuth2AuthorizationRequest.Builder builder;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
builder = OAuth2AuthorizationRequest.authorizationCode();
} else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) {
builder = OAuth2AuthorizationRequest.implicit();
} else {
throw new IllegalArgumentException("Invalid Authorization Grant Type (" +
clientRegistration.getAuthorizationGrantType().getValue() +
") for Client Registration with Id: " + clientRegistration.getRegistrationId());
}
String redirectUriAction = this.resolveRedirectUriAction(request, clientRegistration);
String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = builder
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.additionalParameters(additionalParameters)
.build();
return authorizationRequest;
}
private String resolveRegistrationId(HttpServletRequest request) {
// Check for ClientAuthorizationRequiredException which may have been set
// in the request by OAuth2AuthorizationRequestRedirectFilter
ClientAuthorizationRequiredException authzEx =
(ClientAuthorizationRequiredException) request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
if (authzEx != null) {
return authzEx.getClientRegistrationId();
}
if (this.authorizationRequestMatcher.matches(request)) {
return this.authorizationRequestMatcher
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
}
return null;
}
private String resolveRedirectUriAction(HttpServletRequest request, ClientRegistration clientRegistration) {
String action = null;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
String loginAction = "login";
String authorizeAction = "authorize";
String actionParameter = request.getParameter("action");
if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
// Check for ClientAuthorizationRequiredException which may have been set
// in the request by OAuth2AuthorizationRequestRedirectFilter
action = authorizeAction;
} else if (actionParameter == null) {
action = loginAction; // Default
} else {
if (actionParameter.equalsIgnoreCase(loginAction)) {
action = loginAction;
} else {
action = authorizeAction;
}
}
}
return action;
}
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
// Supported URI variables -> baseUrl, action, registrationId
// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
Map<String, String> uriVariables = new HashMap<>();
uriVariables.put("registrationId", clientRegistration.getRegistrationId());
String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replacePath(request.getContextPath())
.build()
.toUriString();
uriVariables.put("baseUrl", baseUrl);
if (action != null) {
uriVariables.put("action", action);
}
return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
.buildAndExpand(uriVariables)
.toUriString();
}
}

View File

@ -16,34 +16,24 @@
package org.springframework.security.oauth2.client.web;
import org.springframework.http.HttpStatus;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.ThrowableAnalyzer;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.URI;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
/**
* This {@code Filter} initiates the authorization code grant or implicit grant flow
@ -58,19 +48,24 @@ import java.util.Map;
*
* <p>
* By default, this {@code Filter} responds to authorization requests
* at the {@code URI} {@code /oauth2/authorization/{registrationId}}.
* at the {@code URI} {@code /oauth2/authorization/{registrationId}}
* using the default {@link OAuth2AuthorizationRequestResolver}.
* The {@code URI} template variable {@code {registrationId}} represents the
* {@link ClientRegistration#getRegistrationId() registration identifier} of the client
* that is used for initiating the OAuth 2.0 Authorization Request.
*
* <p>
* <b>NOTE:</b> The default base {@code URI} {@code /oauth2/authorization} may be overridden
* via it's constructor {@link #OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository, String)}.
* The default base {@code URI} {@code /oauth2/authorization} may be overridden
* via the constructor {@link #OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository, String)},
* or alternatively, an {@code OAuth2AuthorizationRequestResolver} may be provided to the constructor
* {@link #OAuth2AuthorizationRequestRedirectFilter(OAuth2AuthorizationRequestResolver)}
* to override the resolving of authorization requests.
* @author Joe Grandja
* @author Rob Winch
* @since 5.0
* @see OAuth2AuthorizationRequest
* @see OAuth2AuthorizationRequestResolver
* @see AuthorizationRequestRepository
* @see ClientRegistration
* @see ClientRegistrationRepository
@ -84,18 +79,14 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
* The default base {@code URI} used for authorization requests.
*/
public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization";
private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
private static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
private final AntPathRequestMatcher authorizationRequestMatcher;
private final ClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizationRequestUriBuilder authorizationRequestUriBuilder = new OAuth2AuthorizationRequestUriBuilder();
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
new HttpSessionOAuth2AuthorizationRequestRepository();
private RequestCache requestCache = new HttpSessionRequestCache();
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
/**
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters.
@ -112,14 +103,23 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
* @param clientRegistrationRepository the repository of client registrations
* @param authorizationRequestBaseUri the base {@code URI} used for authorization requests
*/
public OAuth2AuthorizationRequestRedirectFilter(
ClientRegistrationRepository clientRegistrationRepository, String authorizationRequestBaseUri) {
Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty");
public OAuth2AuthorizationRequestRedirectFilter(ClientRegistrationRepository clientRegistrationRepository,
String authorizationRequestBaseUri) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
this.authorizationRequestMatcher = new AntPathRequestMatcher(
authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}");
this.clientRegistrationRepository = clientRegistrationRepository;
Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty");
this.authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
clientRegistrationRepository, authorizationRequestBaseUri);
}
/**
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters.
*
* @since 5.1
* @param authorizationRequestResolver the resolver used for resolving authorization requests
*/
public OAuth2AuthorizationRequestRedirectFilter(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
this.authorizationRequestResolver = authorizationRequestResolver;
}
/**
@ -147,12 +147,14 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
if (this.shouldRequestAuthorization(request, response)) {
try {
this.sendRedirectForAuthorization(request, response);
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
if (authorizationRequest != null) {
this.sendRedirectForAuthorization(request, response, authorizationRequest);
return;
}
} catch (Exception failed) {
this.unsuccessfulRedirectForAuthorization(request, response, failed);
}
return;
}
@ -168,7 +170,11 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
if (authzEx != null) {
try {
request.setAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME, authzEx);
this.sendRedirectForAuthorization(request, response, authzEx.getClientRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
if (authorizationRequest == null) {
throw authzEx;
}
this.sendRedirectForAuthorization(request, response, authorizationRequest);
this.requestCache.saveRequest(request, response);
} catch (Exception failed) {
this.unsuccessfulRedirectForAuthorization(request, response, failed);
@ -188,61 +194,13 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
}
}
private boolean shouldRequestAuthorization(HttpServletRequest request, HttpServletResponse response) {
return this.authorizationRequestMatcher.matches(request);
}
private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response)
throws IOException, ServletException {
String registrationId = this.authorizationRequestMatcher
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
this.sendRedirectForAuthorization(request, response, registrationId);
}
private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
String registrationId) throws IOException, ServletException {
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId);
if (clientRegistration == null) {
throw new IllegalArgumentException("Invalid Client Registration with Id: " + registrationId);
}
this.sendRedirectForAuthorization(request, response, clientRegistration);
}
private void sendRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
ClientRegistration clientRegistration) throws IOException, ServletException {
String redirectUriStr = this.expandRedirectUri(request, clientRegistration);
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
OAuth2AuthorizationRequest.Builder builder;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
builder = OAuth2AuthorizationRequest.authorizationCode();
} else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType())) {
builder = OAuth2AuthorizationRequest.implicit();
} else {
throw new IllegalArgumentException("Invalid Authorization Grant Type (" +
clientRegistration.getAuthorizationGrantType().getValue() +
") for Client Registration with Id: " + clientRegistration.getRegistrationId());
}
OAuth2AuthorizationRequest authorizationRequest = builder
.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr)
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.additionalParameters(additionalParameters)
.build();
OAuth2AuthorizationRequest authorizationRequest) throws IOException, ServletException {
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationRequest.getGrantType())) {
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
}
URI redirectUri = this.authorizationRequestUriBuilder.build(authorizationRequest);
this.authorizationRedirectStrategy.sendRedirect(request, response, redirectUri.toString());
this.authorizationRedirectStrategy.sendRedirect(request, response, authorizationRequest.getAuthorizationRequestUri());
}
private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
@ -254,43 +212,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(), HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration) {
// Supported URI variables -> baseUrl, action, registrationId
// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
Map<String, String> uriVariables = new HashMap<>();
uriVariables.put("registrationId", clientRegistration.getRegistrationId());
String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replacePath(request.getContextPath())
.build()
.toUriString();
uriVariables.put("baseUrl", baseUrl);
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
String loginAction = "login";
String authorizeAction = "authorize";
String actionParameter = "action";
String action;
if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
action = authorizeAction;
} else if (request.getParameter(actionParameter) == null) {
action = loginAction;
} else {
String actionValue = request.getParameter(actionParameter);
if (loginAction.equalsIgnoreCase(actionValue)) {
action = loginAction;
} else {
action = authorizeAction;
}
}
uriVariables.put("action", action);
}
return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
.buildAndExpand(uriVariables)
.toUriString();
}
private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer {
protected void initExtractorMap() {
super.initExtractorMap();

View File

@ -87,7 +87,6 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
private final ServerWebExchangeMatcher authorizationRequestMatcher;
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final OAuth2AuthorizationRequestUriBuilder authorizationRequestUriBuilder = new OAuth2AuthorizationRequestUriBuilder();
private final ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy();
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private ReactiveAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
@ -184,7 +183,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
.saveAuthorizationRequest(authorizationRequest, exchange);
}
URI redirectUri = this.authorizationRequestUriBuilder.build(authorizationRequest);
URI redirectUri = UriComponentsBuilder
.fromUriString(authorizationRequest.getAuthorizationRequestUri())
.build(true).toUri();
return saveAuthorizationRequest
.then(this.authorizationRedirectStrategy.sendRedirect(exchange, redirectUri));
});

View File

@ -0,0 +1,43 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import javax.servlet.http.HttpServletRequest;
/**
* Implementations of this interface are capable of resolving
* an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}.
* Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests.
*
* @author Joe Grandja
* @since 5.1
* @see OAuth2AuthorizationRequest
* @see OAuth2AuthorizationRequestRedirectFilter
*/
public interface OAuth2AuthorizationRequestResolver {
/**
* Returns the {@link OAuth2AuthorizationRequest} resolved from
* the provided {@code HttpServletRequest} or {@code null} if not available.
*
* @param request the {@code HttpServletRequest}
* @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available
*/
OAuth2AuthorizationRequest resolve(HttpServletRequest request);
}

View File

@ -1,53 +0,0 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.Set;
/**
* A {@code URI} builder for an OAuth 2.0 Authorization Request.
*
* @author Joe Grandja
* @since 5.0
* @see OAuth2AuthorizationRequest
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Code Grant Request</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.2.1">Section 4.2.1 Implicit Grant Request</a>
*/
class OAuth2AuthorizationRequestUriBuilder {
URI build(OAuth2AuthorizationRequest authorizationRequest) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
Set<String> scopes = authorizationRequest.getScopes();
UriComponentsBuilder uriBuilder = UriComponentsBuilder
.fromUriString(authorizationRequest.getAuthorizationUri())
.queryParam(OAuth2ParameterNames.RESPONSE_TYPE, authorizationRequest.getResponseType().getValue())
.queryParam(OAuth2ParameterNames.CLIENT_ID, authorizationRequest.getClientId())
.queryParam(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " "))
.queryParam(OAuth2ParameterNames.STATE, authorizationRequest.getState());
if (authorizationRequest.getRedirectUri() != null) {
uriBuilder.queryParam(OAuth2ParameterNames.REDIRECT_URI, authorizationRequest.getRedirectUri());
}
return uriBuilder.build().encode().toUri();
}
}

View File

@ -0,0 +1,242 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import static org.assertj.core.api.Assertions.*;
/**
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
*
* @author Joe Grandja
*/
public class DefaultOAuth2AuthorizationRequestResolverTests {
private ClientRegistration registration1;
private ClientRegistration registration2;
private ClientRegistrationRepository clientRegistrationRepository;
private String authorizationRequestBaseUri = "/oauth2/authorization";
private DefaultOAuth2AuthorizationRequestResolver resolver;
@Before
public void setUp() {
this.registration1 = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
.scope("user")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/user")
.userNameAttributeName("id")
.clientName("client-1")
.build();
this.registration2 = ClientRegistration.withRegistrationId("registration-2")
.clientId("client-2")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
.scope("openid", "profile", "email")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/userinfo")
.jwkSetUri("https://provider.com/oauth2/keys")
.clientName("client-2")
.build();
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
this.registration1, this.registration2);
this.resolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, this.authorizationRequestBaseUri);
}
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new DefaultOAuth2AuthorizationRequestResolver(null, this.authorizationRequestBaseUri))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest).isNull();
}
@Test
public void resolveWhenAuthorizationRequestWithInvalidClientThenThrowIllegalArgumentException() {
ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
assertThatThrownBy(() -> this.resolver.resolve(request))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Invalid Client Registration with Id: " + clientRegistration.getRegistrationId() + "-invalid");
}
@Test
public void resolveWhenAuthorizationRequestWithValidClientThenResolves() {
ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(
clientRegistration.getProviderDetails().getAuthorizationUri());
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
assertThat(authorizationRequest.getClientId()).isEqualTo(clientRegistration.getClientId());
assertThat(authorizationRequest.getRedirectUri())
.isEqualTo("http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
assertThat(authorizationRequest.getScopes()).isEqualTo(clientRegistration.getScopes());
assertThat(authorizationRequest.getState()).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
}
@Test
public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenResolves() {
ClientRegistration clientRegistration = this.registration2;
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setAttribute(
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
}
@Test
public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() {
ClientRegistration clientRegistration = this.registration2;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
clientRegistration.getRedirectUriTemplate());
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
"http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
}
@Test
public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() {
ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setScheme("http");
request.setServerName("example.com");
request.setServerPort(80);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Fexample.com%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
}
@Test
public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() {
ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setScheme("https");
request.setServerName("example.com");
request.setServerPort(443);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=https%3A%2F%2Fexample.com%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
}
@Test
public void resolveWhenClientAuthorizationRequiredExceptionAvailableThenRedirectUriIsAuthorize() {
ClientRegistration clientRegistration = this.registration1;
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setAttribute(
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
}
@Test
public void resolveWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() {
ClientRegistration clientRegistration = this.registration2;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-2");
}
@Test
public void resolveWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() {
ClientRegistration clientRegistration = this.registration1;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.addParameter("action", "authorize");
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
}
@Test
public void resolveWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() {
ClientRegistration clientRegistration = this.registration2;
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.addParameter("action", "login");
request.setServletPath(requestUri);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-2");
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,7 +17,6 @@ package org.springframework.security.oauth2.client.web;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpStatus;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
@ -29,16 +28,22 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.ClassUtils;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Constructor;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
/**
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
@ -100,7 +105,9 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
@Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizationRequestRedirectFilter(null))
Constructor<OAuth2AuthorizationRequestRedirectFilter> constructor = ClassUtils.getConstructorIfAvailable(
OAuth2AuthorizationRequestRedirectFilter.class, ClientRegistrationRepository.class);
assertThatThrownBy(() -> constructor.newInstance(null))
.isInstanceOf(IllegalArgumentException.class);
}
@ -110,6 +117,14 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void constructorWhenAuthorizationRequestResolverIsNullThenThrowIllegalArgumentException() {
Constructor<OAuth2AuthorizationRequestRedirectFilter> constructor = ClassUtils.getConstructorIfAvailable(
OAuth2AuthorizationRequestRedirectFilter.class, OAuth2AuthorizationRequestResolver.class);
assertThatThrownBy(() -> constructor.newInstance(null))
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.filter.setAuthorizationRequestRepository(null))
@ -165,7 +180,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-1");
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
}
@Test
@ -201,7 +216,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=token&client_id=client-3&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/implicit/registration-3");
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=token&client_id=client-3&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fimplicit%2Fregistration-3");
}
@Test
@ -239,75 +254,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-1");
}
@Test
public void doFilterWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpanded() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
mock(AuthorizationRequestRepository.class);
this.filter.setAuthorizationRequestRepository(authorizationRequestRepository);
this.filter.doFilter(request, response, filterChain);
ArgumentCaptor<OAuth2AuthorizationRequest> authorizationRequestArgCaptor =
ArgumentCaptor.forClass(OAuth2AuthorizationRequest.class);
verifyZeroInteractions(filterChain);
verify(authorizationRequestRepository).saveAuthorizationRequest(
authorizationRequestArgCaptor.capture(), any(HttpServletRequest.class), any(HttpServletResponse.class));
OAuth2AuthorizationRequest authorizationRequest = authorizationRequestArgCaptor.getValue();
assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(
this.registration2.getRedirectUriTemplate());
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
"http://localhost/login/oauth2/code/registration-2");
}
@Test
public void doFilterWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriExcludesPort() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setScheme("http");
request.setServerName("example.com");
request.setServerPort(80);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://example.com/login/oauth2/code/registration-1");
}
@Test
public void doFilterWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriExcludesPort() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setScheme("https");
request.setServerName("example.com");
request.setServerPort(443);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=https://example.com/login/oauth2/code/registration-1");
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1");
}
@Test
@ -325,13 +272,13 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/code/registration-1");
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME)).isNull();
}
@Test
public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownThenRedirectUriIsAuthorize() throws Exception {
public void doFilterWhenNotAuthorizationRequestAndClientAuthorizationRequiredExceptionThrownButAuthorizationRequestNotResolvedThenStatusInternalServerError() throws Exception {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
@ -341,60 +288,84 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId()))
.when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
this.filter.doFilter(request, response, filterChain);
OAuth2AuthorizationRequestResolver resolver = req -> null;
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/code/registration-1");
}
@Test
public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectUriIsLogin() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration2.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-2");
assertThat(response.getStatus()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.value());
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}
// gh-4911
@Test
public void doFilterWhenAuthorizationRequestHasActionParameterAuthorizeThenRedirectUriIsAuthorize() throws Exception {
public void doFilterWhenAuthorizationRequestAndAdditionalParametersProvidedThenAuthorizationRequestIncludesAdditionalParameters() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.addParameter("action", "authorize");
request.setServletPath(requestUri);
request.addParameter("idp", "https://other.provider.com");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
OAuth2AuthorizationRequestResolver resolver = req -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("idp", req.getParameter("idp"));
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http://localhost/authorize/oauth2/code/registration-1");
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1&idp=https%3A%2F%2Fother.provider.com");
}
// gh-4911, gh-5244
@Test
public void doFilterWhenAuthorizationRequestHasActionParameterLoginThenRedirectUriIsLogin() throws Exception {
public void doFilterWhenAuthorizationRequestAndCustomAuthorizationRequestUriSetThenCustomAuthorizationRequestUriUsed() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI +
"/" + this.registration2.getRegistrationId();
"/" + this.registration1.getRegistrationId();
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.addParameter("action", "login");
request.setServletPath(requestUri);
String loginHintParamName = "login_hint";
request.addParameter(loginHintParamName, "user@provider.com");
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
OAuth2AuthorizationRequestResolver resolver = req -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put(loginHintParamName, req.getParameter(loginHintParamName));
String customAuthorizationRequestUri = UriComponentsBuilder
.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
.build(true).toUriString();
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.authorizationRequestUri(customAuthorizationRequestUri)
.build();
};
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain);
verifyZeroInteractions(filterChain);
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-2&scope=openid%20profile%20email&state=.{15,}&redirect_uri=http://localhost/login/oauth2/code/registration-2");
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fregistration-1&login_hint=user@provider\\.com");
}
}

View File

@ -1,58 +0,0 @@
/*
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web;
import org.junit.Test;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link OAuth2AuthorizationRequestUriBuilder}.
*
* @author Rob Winch
* @since 5.0
*/
public class OAuth2AuthorizationRequestUriBuilderTests {
private OAuth2AuthorizationRequestUriBuilder builder = new OAuth2AuthorizationRequestUriBuilder();
@Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() {
this.builder.build(null);
}
@Test
public void buildWhenScopeMultiThenSeparatedByEncodedSpace() {
OAuth2AuthorizationRequest request = OAuth2AuthorizationRequest.implicit()
.additionalParameters(Collections.singletonMap("foo", "bar"))
.authorizationUri("https://idp.example.com/oauth2/v2/auth")
.clientId("client-id")
.state("thestate")
.redirectUri("https://client.example.com/login/oauth2")
.scopes(new HashSet<>(Arrays.asList("openid", "user")))
.build();
URI result = this.builder.build(request);
assertThat(result.toASCIIString()).isEqualTo("https://idp.example.com/oauth2/v2/auth?response_type=token&client_id=client-id&scope=openid%20user&state=thestate&redirect_uri=https://client.example.com/login/oauth2");
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,14 +19,18 @@ import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;
/**
@ -50,6 +54,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private Set<String> scopes;
private String state;
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private OAuth2AuthorizationRequest() {
}
@ -126,6 +131,20 @@ public final class OAuth2AuthorizationRequest implements Serializable {
return this.additionalParameters;
}
/**
* Returns the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
* <p>
* <b>NOTE:</b> The {@code URI} string is encoded in the
* {@code application/x-www-form-urlencoded} MIME format.
*
* @since 5.1
* @return the {@code URI} string representation of the OAuth 2.0 Authorization Request
*/
public String getAuthorizationRequestUri() {
return this.authorizationRequestUri;
}
/**
* Returns a new {@link Builder}, initialized with the authorization code grant type.
*
@ -144,6 +163,26 @@ public final class OAuth2AuthorizationRequest implements Serializable {
return new Builder(AuthorizationGrantType.IMPLICIT);
}
/**
* Returns a new {@link Builder}, initialized with the values
* from the provided {@code authorizationRequest}.
*
* @since 5.1
* @param authorizationRequest the authorization request used for initializing the {@link Builder}
* @return the {@link Builder}
*/
public static Builder from(OAuth2AuthorizationRequest authorizationRequest) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
return new Builder(authorizationRequest.getGrantType())
.authorizationUri(authorizationRequest.getAuthorizationUri())
.clientId(authorizationRequest.getClientId())
.redirectUri(authorizationRequest.getRedirectUri())
.scopes(authorizationRequest.getScopes())
.state(authorizationRequest.getState())
.additionalParameters(authorizationRequest.getAdditionalParameters());
}
/**
* A builder for {@link OAuth2AuthorizationRequest}.
*/
@ -156,6 +195,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private Set<String> scopes;
private String state;
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private Builder(AuthorizationGrantType authorizationGrantType) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
@ -247,6 +287,22 @@ public final class OAuth2AuthorizationRequest implements Serializable {
return this;
}
/**
* Sets the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
* <p>
* <b>NOTE:</b> The {@code URI} string is <b>required</b> to be encoded in the
* {@code application/x-www-form-urlencoded} MIME format.
*
* @since 5.1
* @param authorizationRequestUri the {@code URI} string representation of the OAuth 2.0 Authorization Request
* @return the {@link Builder}
*/
public Builder authorizationRequestUri(String authorizationRequestUri) {
this.authorizationRequestUri = authorizationRequestUri;
return this;
}
/**
* Builds a new {@link OAuth2AuthorizationRequest}.
*
@ -272,7 +328,42 @@ public final class OAuth2AuthorizationRequest implements Serializable {
authorizationRequest.additionalParameters = Collections.unmodifiableMap(
CollectionUtils.isEmpty(this.additionalParameters) ?
Collections.emptyMap() : new LinkedHashMap<>(this.additionalParameters));
authorizationRequest.authorizationRequestUri =
StringUtils.hasText(this.authorizationRequestUri) ?
this.authorizationRequestUri : this.buildAuthorizationRequestUri();
return authorizationRequest;
}
private String buildAuthorizationRequestUri() {
Map<String, String> parameters = new LinkedHashMap<>();
parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue());
parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId);
if (!CollectionUtils.isEmpty(this.scopes)) {
parameters.put(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(this.scopes, " "));
}
if (this.state != null) {
parameters.put(OAuth2ParameterNames.STATE, this.state);
}
if (this.redirectUri != null) {
parameters.put(OAuth2ParameterNames.REDIRECT_URI, this.redirectUri);
}
if (!CollectionUtils.isEmpty(this.additionalParameters)) {
this.additionalParameters.entrySet().stream()
.filter(e -> !e.getKey().equals(OAuth2ParameterNames.REGISTRATION_ID))
.forEach(e -> parameters.put(e.getKey(), e.getValue().toString()));
}
try {
StringJoiner queryParams = new StringJoiner("&");
for (String paramName : parameters.keySet()) {
queryParams.add(paramName + "=" + URLEncoder.encode(parameters.get(paramName), "UTF-8"));
}
return this.authorizationUri + "?" + queryParams.toString();
} catch (UnsupportedEncodingException ex) {
throw new IllegalArgumentException("Unable to build authorization request uri: " + ex.getMessage(), ex);
}
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,9 +16,6 @@
package org.springframework.security.oauth2.core.endpoint;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import java.util.Arrays;
@ -27,8 +24,7 @@ import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.*;
/**
* Tests for {@link OAuth2AuthorizationRequest}.
@ -36,8 +32,6 @@ import static org.assertj.core.api.Assertions.assertThatCode;
* @author Luander Ribeiro
* @author Joe Grandja
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest(OAuth2AuthorizationRequest.class)
public class OAuth2AuthorizationRequestTests {
private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorize";
private static final String CLIENT_ID = "client-id";
@ -45,48 +39,96 @@ public class OAuth2AuthorizationRequestTests {
private static final Set<String> SCOPES = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
private static final String STATE = "state";
@Test(expected = IllegalArgumentException.class)
@Test
public void buildWhenAuthorizationUriIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() ->
OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(null)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.build();
.build()
).isInstanceOf(IllegalArgumentException.class);
}
@Test(expected = IllegalArgumentException.class)
@Test
public void buildWhenClientIdIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() ->
OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(null)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.build();
.build()
).isInstanceOf(IllegalArgumentException.class);
}
@Test(expected = IllegalArgumentException.class)
@Test
public void buildWhenRedirectUriIsNullForImplicitThenThrowIllegalArgumentException() {
assertThatThrownBy(() ->
OAuth2AuthorizationRequest.implicit()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(null)
.scopes(SCOPES)
.state(STATE)
.build();
.build()
).isInstanceOf(IllegalArgumentException.class);
}
@Test
public void buildWhenRedirectUriIsNullForAuthorizationCodeThenDoesNotThrowAnyException() {
assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
assertThatCode(() ->
OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(null)
.scopes(SCOPES)
.state(STATE)
.build()).doesNotThrowAnyException();
.build())
.doesNotThrowAnyException();
}
@Test
public void buildWhenScopesIsNullThenDoesNotThrowAnyException() {
assertThatCode(() ->
OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(null)
.state(STATE)
.build())
.doesNotThrowAnyException();
}
@Test
public void buildWhenStateIsNullThenDoesNotThrowAnyException() {
assertThatCode(() ->
OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(null)
.build())
.doesNotThrowAnyException();
}
@Test
public void buildWhenAdditionalParametersIsNullThenDoesNotThrowAnyException() {
assertThatCode(() ->
OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.additionalParameters(null)
.build())
.doesNotThrowAnyException();
}
@Test
@ -116,7 +158,60 @@ public class OAuth2AuthorizationRequestTests {
}
@Test
public void buildWhenAllAttributesProvidedThenAllAttributesAreSet() {
public void buildWhenAllValuesProvidedThenAllValuesAreSet() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.additionalParameters(additionalParameters)
.authorizationRequestUri(AUTHORIZATION_URI)
.build();
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
assertThat(authorizationRequest.getClientId()).isEqualTo(CLIENT_ID);
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(REDIRECT_URI);
assertThat(authorizationRequest.getScopes()).isEqualTo(SCOPES);
assertThat(authorizationRequest.getState()).isEqualTo(STATE);
assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(additionalParameters);
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
}
@Test
public void buildWhenScopesMultiThenSeparatedByEncodedSpace() {
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.implicit()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.build();
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=token&client_id=client-id&scope=scope1+scope2&state=state&redirect_uri=http%3A%2F%2Fexample.com");
}
@Test
public void buildWhenAuthorizationRequestUriSetThenOverridesDefault() {
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.authorizationRequestUri(AUTHORIZATION_URI)
.build();
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo(AUTHORIZATION_URI);
}
@Test
public void buildWhenAuthorizationRequestUriNotSetThenDefaultSet() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
@ -130,47 +225,69 @@ public class OAuth2AuthorizationRequestTests {
.additionalParameters(additionalParameters)
.build();
assertThat(authorizationRequest.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
assertThat(authorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(authorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
assertThat(authorizationRequest.getClientId()).isEqualTo(CLIENT_ID);
assertThat(authorizationRequest.getRedirectUri()).isEqualTo(REDIRECT_URI);
assertThat(authorizationRequest.getScopes()).isEqualTo(SCOPES);
assertThat(authorizationRequest.getState()).isEqualTo(STATE);
assertThat(authorizationRequest.getAdditionalParameters()).isEqualTo(additionalParameters);
assertThat(authorizationRequest.getAuthorizationRequestUri()).isNotNull();
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id&scope=scope1+scope2&state=state&redirect_uri=http%3A%2F%2Fexample.com&param1=value1&param2=value2");
}
@Test
public void buildWhenScopesIsNullThenDoesNotThrowAnyException() {
assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
public void buildWhenRequiredParametersSetThenAuthorizationRequestUriIncludesRequiredParametersOnly() {
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(null)
.state(STATE)
.build()).doesNotThrowAnyException();
.build();
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id");
}
@Test
public void buildWhenStateIsNullThenDoesNotThrowAnyException() {
assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(null)
.build()).doesNotThrowAnyException();
}
public void buildWhenAuthorizationRequestIncludesRegistrationIdParameterThenAuthorizationRequestUriDoesNotIncludeRegistrationIdParameter() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put(OAuth2ParameterNames.REGISTRATION_ID, "registration1");
@Test
public void buildWhenAdditionalParametersIsNullThenDoesNotThrowAnyException() {
assertThatCode(() -> OAuth2AuthorizationRequest.authorizationCode()
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.additionalParameters(null)
.build()).doesNotThrowAnyException();
.additionalParameters(additionalParameters)
.build();
assertThat(authorizationRequest.getAuthorizationRequestUri()).isEqualTo("https://provider.com/oauth2/authorize?response_type=code&client_id=client-id&scope=scope1+scope2&state=state&redirect_uri=http%3A%2F%2Fexample.com&param1=value1");
}
@Test
public void fromWhenAuthorizationRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> OAuth2AuthorizationRequest.from(null)).isInstanceOf(IllegalArgumentException.class);
}
@Test
public void fromWhenAuthorizationRequestProvidedThenValuesAreCopied() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.additionalParameters(additionalParameters)
.build();
OAuth2AuthorizationRequest authorizationRequestCopy =
OAuth2AuthorizationRequest.from(authorizationRequest).build();
assertThat(authorizationRequestCopy.getAuthorizationUri()).isEqualTo(authorizationRequest.getAuthorizationUri());
assertThat(authorizationRequestCopy.getGrantType()).isEqualTo(authorizationRequest.getGrantType());
assertThat(authorizationRequestCopy.getResponseType()).isEqualTo(authorizationRequest.getResponseType());
assertThat(authorizationRequestCopy.getClientId()).isEqualTo(authorizationRequest.getClientId());
assertThat(authorizationRequestCopy.getRedirectUri()).isEqualTo(authorizationRequest.getRedirectUri());
assertThat(authorizationRequestCopy.getScopes()).isEqualTo(authorizationRequest.getScopes());
assertThat(authorizationRequestCopy.getState()).isEqualTo(authorizationRequest.getState());
assertThat(authorizationRequestCopy.getAdditionalParameters()).isEqualTo(authorizationRequest.getAdditionalParameters());
assertThat(authorizationRequestCopy.getAuthorizationRequestUri()).isEqualTo(authorizationRequest.getAuthorizationRequestUri());
}
}

View File

@ -88,7 +88,7 @@ public class OAuth2AuthorizationCodeGrantApplicationTests {
MvcResult mvcResult = this.mockMvc.perform(get("/repos").with(user("user")))
.andExpect(status().is3xxRedirection())
.andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://github.com/login/oauth/authorize\\?response_type=code&client_id=your-app-client-id&scope=public_repo&state=.{15,}&redirect_uri=http://localhost/github-repos");
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://github.com/login/oauth/authorize\\?response_type=code&client_id=your-app-client-id&scope=public_repo&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fgithub-repos");
}
@Test