Add OAuth2AuthorizationRequestResolver.resolve(HttpServletRequest,String)

Previously there was a tangle between
DefaultOAuth2AuthorizationRequestResolver and
OAuth2AuthorizationRequestRedirectFilter with
AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME

This commit adds a new method that can be used for resolving the
OAuth2AuthorizationRequest when the client registration id is known.

Issue: gh-4911
This commit is contained in:
Rob Winch 2018-08-16 20:41:13 -05:00
parent 06df562d61
commit 938dbbf424
7 changed files with 87 additions and 100 deletions

View File

@ -192,21 +192,16 @@ public class OAuth2ClientConfigurerTests {
public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
// Override default resolver
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
authorizationRequestResolver = request -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("param1", "value1");
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class);
when(authorizationRequestResolver.resolve(any())).thenAnswer(invocation -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0)));
this.spring.register(OAuth2ClientConfig.class).autowire();
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
.andExpect(status().is3xxRedirection())
.andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1&param1=value1");
verify(authorizationRequestResolver).resolve(any());
}
@EnableWebSecurity

View File

@ -44,7 +44,6 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
@ -78,6 +77,9 @@ import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link OAuth2LoginConfigurer}.
@ -236,6 +238,15 @@ public class OAuth2LoginConfigurerTests {
@Test
public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
OAuth2AuthorizationRequestResolver resolver = this.context.getBean(
OAuth2LoginConfigCustomAuthorizationRequestResolver.class).resolver;
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://accounts.google.com/authorize")
.clientId("client-id")
.state("adsfa")
.authorizationRequestUri("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1")
.build();
when(resolver.resolve(any())).thenReturn(result);
String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri);
@ -243,7 +254,7 @@ public class OAuth2LoginConfigurerTests {
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/v2/auth\\?response_type=code&client_id=clientId&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
assertThat(this.response.getRedirectedUrl()).isEqualTo("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
}
// gh-5347
@ -492,28 +503,17 @@ public class OAuth2LoginConfigurerTests {
private ClientRegistrationRepository clientRegistrationRepository =
new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION);
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
@Override
protected void configure(HttpSecurity http) throws Exception {
http
.oauth2Login()
.clientRegistrationRepository(this.clientRegistrationRepository)
.authorizationEndpoint()
.authorizationRequestResolver(this.getAuthorizationRequestResolver());
.authorizationRequestResolver(this.resolver);
super.configure(http);
}
private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver =
new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, "/oauth2/authorization");
return request -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("custom-param1", "custom-value1");
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
}
}
@EnableWebSecurity

View File

@ -17,7 +17,6 @@ package org.springframework.security.oauth2.client.web;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -33,8 +32,6 @@ import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
/**
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
* resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}
@ -45,6 +42,7 @@ import static org.springframework.security.oauth2.client.web.OAuth2Authorization
* via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
*
* @author Joe Grandja
* @author Rob Winch
* @since 5.1
* @see OAuth2AuthorizationRequestResolver
* @see OAuth2AuthorizationRequestRedirectFilter
@ -73,6 +71,28 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
@Override
public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
String registrationId = this.resolveRegistrationId(request);
String redirectUriAction = getAction(request, "login");
return resolve(request, registrationId, redirectUriAction);
}
@Override
public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId) {
if (registrationId == null) {
return null;
}
String redirectUriAction = getAction(request, "authorize");
return resolve(request, registrationId, redirectUriAction);
}
private String getAction(HttpServletRequest request, String defaultAction) {
String action = request.getParameter("action");
if (action == null) {
return defaultAction;
}
return action;
}
private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId, String redirectUriAction) {
if (registrationId == null) {
return null;
}
@ -93,7 +113,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
") for Client Registration with Id: " + clientRegistration.getRegistrationId());
}
String redirectUriAction = this.resolveRedirectUriAction(request, clientRegistration);
String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
Map<String, Object> additionalParameters = new HashMap<>();
@ -112,13 +131,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
}
private String resolveRegistrationId(HttpServletRequest request) {
// Check for ClientAuthorizationRequiredException which may have been set
// in the request by OAuth2AuthorizationRequestRedirectFilter
ClientAuthorizationRequiredException authzEx =
(ClientAuthorizationRequiredException) request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
if (authzEx != null) {
return authzEx.getClientRegistrationId();
}
if (this.authorizationRequestMatcher.matches(request)) {
return this.authorizationRequestMatcher
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
@ -126,29 +138,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
return null;
}
private String resolveRedirectUriAction(HttpServletRequest request, ClientRegistration clientRegistration) {
String action = null;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
String loginAction = "login";
String authorizeAction = "authorize";
String actionParameter = request.getParameter("action");
if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
// Check for ClientAuthorizationRequiredException which may have been set
// in the request by OAuth2AuthorizationRequestRedirectFilter
action = authorizeAction;
} else if (actionParameter == null) {
action = loginAction; // Default
} else {
if (actionParameter.equalsIgnoreCase(loginAction)) {
action = loginAction;
} else {
action = authorizeAction;
}
}
}
return action;
}
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
// Supported URI variables -> baseUrl, action, registrationId
// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"

View File

@ -79,8 +79,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
* The default base {@code URI} used for authorization requests.
*/
public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization";
static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
@ -169,8 +167,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
.getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain);
if (authzEx != null) {
try {
request.setAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME, authzEx);
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request, authzEx.getClientRegistrationId());
if (authorizationRequest == null) {
throw authzEx;
}
@ -178,8 +175,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
this.requestCache.saveRequest(request, response);
} catch (Exception failed) {
this.unsuccessfulRedirectForAuthorization(request, response, failed);
} finally {
request.removeAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
}
return;
}

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletRequest;
* Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests.
*
* @author Joe Grandja
* @author Rob Winch
* @since 5.1
* @see OAuth2AuthorizationRequest
* @see OAuth2AuthorizationRequestRedirectFilter
@ -40,4 +41,14 @@ public interface OAuth2AuthorizationRequestResolver {
*/
OAuth2AuthorizationRequest resolve(HttpServletRequest request);
/**
* Returns the {@link OAuth2AuthorizationRequest} resolved from
* the provided {@code HttpServletRequest} or {@code null} if not available.
*
* @param request the {@code HttpServletRequest}
* @param clientRegistrationId the clientRegistrationId to use
* @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available
*/
OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId);
}

View File

@ -18,7 +18,6 @@ package org.springframework.security.oauth2.client.web;
import org.junit.Before;
import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
@ -28,7 +27,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.entry;
/**
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
@ -139,11 +140,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setAttribute(
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
@ -213,11 +211,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
request.setAttribute(
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
}

View File

@ -37,13 +37,13 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Constructor;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
/**
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
@ -274,7 +274,6 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME)).isNull();
}
@Test
@ -288,7 +287,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId()))
.when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
OAuth2AuthorizationRequestResolver resolver = req -> null;
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain);
@ -315,14 +314,13 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
OAuth2AuthorizationRequestResolver resolver = req -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("idp", req.getParameter("idp"));
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
.from(defaultAuthorizationRequestResolver.resolve(request))
.additionalParameters(
Collections.singletonMap("idp", request.getParameter("idp")))
.build();
when(resolver.resolve(any())).thenReturn(result);
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain);
@ -347,19 +345,23 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
OAuth2AuthorizationRequestResolver resolver = req -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put(loginHintParamName, req.getParameter(loginHintParamName));
String customAuthorizationRequestUri = UriComponentsBuilder
.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
.build(true).toUriString();
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.authorizationRequestUri(customAuthorizationRequestUri)
.build();
};
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put(loginHintParamName, request.getParameter(loginHintParamName));
String customAuthorizationRequestUri = UriComponentsBuilder
.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
.build(true).toUriString();
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
.from(defaultAuthorizationRequestResolver.resolve(request))
.additionalParameters(
Collections.singletonMap("idp", request.getParameter("idp")))
.authorizationRequestUri(customAuthorizationRequestUri)
.build();
when(resolver.resolve(any())).thenReturn(result);
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain);