Add OAuth2AuthorizationRequestResolver.resolve(HttpServletRequest,String)

Previously there was a tangle between
DefaultOAuth2AuthorizationRequestResolver and
OAuth2AuthorizationRequestRedirectFilter with
AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME

This commit adds a new method that can be used for resolving the
OAuth2AuthorizationRequest when the client registration id is known.

Issue: gh-4911
This commit is contained in:
Rob Winch 2018-08-16 20:41:13 -05:00
parent 06df562d61
commit 938dbbf424
7 changed files with 87 additions and 100 deletions

View File

@ -192,21 +192,16 @@ public class OAuth2ClientConfigurerTests {
public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception { public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
// Override default resolver // Override default resolver
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver; OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
authorizationRequestResolver = request -> { authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request); when(authorizationRequestResolver.resolve(any())).thenAnswer(invocation -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0)));
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("param1", "value1");
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
this.spring.register(OAuth2ClientConfig.class).autowire(); this.spring.register(OAuth2ClientConfig.class).autowire();
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1")) this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
.andExpect(status().is3xxRedirection()) .andExpect(status().is3xxRedirection())
.andReturn(); .andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1&param1=value1");
verify(authorizationRequestResolver).resolve(any());
} }
@EnableWebSecurity @EnableWebSecurity

View File

@ -44,7 +44,6 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter; import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
@ -78,6 +77,9 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/** /**
* Tests for {@link OAuth2LoginConfigurer}. * Tests for {@link OAuth2LoginConfigurer}.
@ -236,6 +238,15 @@ public class OAuth2LoginConfigurerTests {
@Test @Test
public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception { public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class); loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
OAuth2AuthorizationRequestResolver resolver = this.context.getBean(
OAuth2LoginConfigCustomAuthorizationRequestResolver.class).resolver;
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://accounts.google.com/authorize")
.clientId("client-id")
.state("adsfa")
.authorizationRequestUri("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1")
.build();
when(resolver.resolve(any())).thenReturn(result);
String requestUri = "/oauth2/authorization/google"; String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri); this.request = new MockHttpServletRequest("GET", requestUri);
@ -243,7 +254,7 @@ public class OAuth2LoginConfigurerTests {
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/v2/auth\\?response_type=code&client_id=clientId&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); assertThat(this.response.getRedirectedUrl()).isEqualTo("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
} }
// gh-5347 // gh-5347
@ -492,28 +503,17 @@ public class OAuth2LoginConfigurerTests {
private ClientRegistrationRepository clientRegistrationRepository = private ClientRegistrationRepository clientRegistrationRepository =
new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION); new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION);
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
@Override @Override
protected void configure(HttpSecurity http) throws Exception { protected void configure(HttpSecurity http) throws Exception {
http http
.oauth2Login() .oauth2Login()
.clientRegistrationRepository(this.clientRegistrationRepository) .clientRegistrationRepository(this.clientRegistrationRepository)
.authorizationEndpoint() .authorizationEndpoint()
.authorizationRequestResolver(this.getAuthorizationRequestResolver()); .authorizationRequestResolver(this.resolver);
super.configure(http); super.configure(http);
} }
private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver =
new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, "/oauth2/authorization");
return request -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("custom-param1", "custom-value1");
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
}
} }
@EnableWebSecurity @EnableWebSecurity

View File

@ -17,7 +17,6 @@ package org.springframework.security.oauth2.client.web;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
@ -33,8 +32,6 @@ import java.util.Base64;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
/** /**
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to * An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
* resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest} * resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}
@ -45,6 +42,7 @@ import static org.springframework.security.oauth2.client.web.OAuth2Authorization
* via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}. * via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
* *
* @author Joe Grandja * @author Joe Grandja
* @author Rob Winch
* @since 5.1 * @since 5.1
* @see OAuth2AuthorizationRequestResolver * @see OAuth2AuthorizationRequestResolver
* @see OAuth2AuthorizationRequestRedirectFilter * @see OAuth2AuthorizationRequestRedirectFilter
@ -73,6 +71,28 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
@Override @Override
public OAuth2AuthorizationRequest resolve(HttpServletRequest request) { public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
String registrationId = this.resolveRegistrationId(request); String registrationId = this.resolveRegistrationId(request);
String redirectUriAction = getAction(request, "login");
return resolve(request, registrationId, redirectUriAction);
}
@Override
public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId) {
if (registrationId == null) {
return null;
}
String redirectUriAction = getAction(request, "authorize");
return resolve(request, registrationId, redirectUriAction);
}
private String getAction(HttpServletRequest request, String defaultAction) {
String action = request.getParameter("action");
if (action == null) {
return defaultAction;
}
return action;
}
private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId, String redirectUriAction) {
if (registrationId == null) { if (registrationId == null) {
return null; return null;
} }
@ -93,7 +113,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
") for Client Registration with Id: " + clientRegistration.getRegistrationId()); ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
} }
String redirectUriAction = this.resolveRedirectUriAction(request, clientRegistration);
String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction); String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
Map<String, Object> additionalParameters = new HashMap<>(); Map<String, Object> additionalParameters = new HashMap<>();
@ -112,13 +131,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
} }
private String resolveRegistrationId(HttpServletRequest request) { private String resolveRegistrationId(HttpServletRequest request) {
// Check for ClientAuthorizationRequiredException which may have been set
// in the request by OAuth2AuthorizationRequestRedirectFilter
ClientAuthorizationRequiredException authzEx =
(ClientAuthorizationRequiredException) request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
if (authzEx != null) {
return authzEx.getClientRegistrationId();
}
if (this.authorizationRequestMatcher.matches(request)) { if (this.authorizationRequestMatcher.matches(request)) {
return this.authorizationRequestMatcher return this.authorizationRequestMatcher
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME); .extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
@ -126,29 +138,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
return null; return null;
} }
private String resolveRedirectUriAction(HttpServletRequest request, ClientRegistration clientRegistration) {
String action = null;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
String loginAction = "login";
String authorizeAction = "authorize";
String actionParameter = request.getParameter("action");
if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
// Check for ClientAuthorizationRequiredException which may have been set
// in the request by OAuth2AuthorizationRequestRedirectFilter
action = authorizeAction;
} else if (actionParameter == null) {
action = loginAction; // Default
} else {
if (actionParameter.equalsIgnoreCase(loginAction)) {
action = loginAction;
} else {
action = authorizeAction;
}
}
}
return action;
}
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) { private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
// Supported URI variables -> baseUrl, action, registrationId // Supported URI variables -> baseUrl, action, registrationId
// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}" // Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"

View File

@ -79,8 +79,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
* The default base {@code URI} used for authorization requests. * The default base {@code URI} used for authorization requests.
*/ */
public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization"; public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization";
static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer(); private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
private OAuth2AuthorizationRequestResolver authorizationRequestResolver; private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
@ -169,8 +167,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
.getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain); .getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain);
if (authzEx != null) { if (authzEx != null) {
try { try {
request.setAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME, authzEx); OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request, authzEx.getClientRegistrationId());
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
if (authorizationRequest == null) { if (authorizationRequest == null) {
throw authzEx; throw authzEx;
} }
@ -178,8 +175,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
this.requestCache.saveRequest(request, response); this.requestCache.saveRequest(request, response);
} catch (Exception failed) { } catch (Exception failed) {
this.unsuccessfulRedirectForAuthorization(request, response, failed); this.unsuccessfulRedirectForAuthorization(request, response, failed);
} finally {
request.removeAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
} }
return; return;
} }

View File

@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletRequest;
* Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests. * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests.
* *
* @author Joe Grandja * @author Joe Grandja
* @author Rob Winch
* @since 5.1 * @since 5.1
* @see OAuth2AuthorizationRequest * @see OAuth2AuthorizationRequest
* @see OAuth2AuthorizationRequestRedirectFilter * @see OAuth2AuthorizationRequestRedirectFilter
@ -40,4 +41,14 @@ public interface OAuth2AuthorizationRequestResolver {
*/ */
OAuth2AuthorizationRequest resolve(HttpServletRequest request); OAuth2AuthorizationRequest resolve(HttpServletRequest request);
/**
* Returns the {@link OAuth2AuthorizationRequest} resolved from
* the provided {@code HttpServletRequest} or {@code null} if not available.
*
* @param request the {@code HttpServletRequest}
* @param clientRegistrationId the clientRegistrationId to use
* @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available
*/
OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId);
} }

View File

@ -18,7 +18,6 @@ package org.springframework.security.oauth2.client.web;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
@ -28,7 +27,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.entry;
/** /**
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}. * Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
@ -139,11 +140,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri); request.setServletPath(requestUri);
request.setAttribute(
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
assertThat(authorizationRequest).isNotNull(); assertThat(authorizationRequest).isNotNull();
assertThat(authorizationRequest.getAdditionalParameters()) assertThat(authorizationRequest.getAdditionalParameters())
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId())); .containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
@ -213,11 +211,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
String requestUri = "/path"; String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri); request.setServletPath(requestUri);
request.setAttribute(
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1"); assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
} }

View File

@ -37,13 +37,13 @@ import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
/** /**
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}. * Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
@ -274,7 +274,6 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1"); assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
assertThat(request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME)).isNull();
} }
@Test @Test
@ -288,7 +287,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId())) doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId()))
.when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); .when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
OAuth2AuthorizationRequestResolver resolver = req -> null; OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver); OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);
@ -315,14 +314,13 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver( OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
OAuth2AuthorizationRequestResolver resolver = req -> { OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req); OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters()); .from(defaultAuthorizationRequestResolver.resolve(request))
additionalParameters.put("idp", req.getParameter("idp")); .additionalParameters(
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest) Collections.singletonMap("idp", request.getParameter("idp")))
.additionalParameters(additionalParameters) .build();
.build(); when(resolver.resolve(any())).thenReturn(result);
};
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver); OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);
@ -347,19 +345,23 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver( OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
OAuth2AuthorizationRequestResolver resolver = req -> { OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters()); OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
additionalParameters.put(loginHintParamName, req.getParameter(loginHintParamName)); Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
String customAuthorizationRequestUri = UriComponentsBuilder additionalParameters.put(loginHintParamName, request.getParameter(loginHintParamName));
.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri()) String customAuthorizationRequestUri = UriComponentsBuilder
.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName)) .fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
.build(true).toUriString(); .queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest) .build(true).toUriString();
.additionalParameters(additionalParameters) OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
.authorizationRequestUri(customAuthorizationRequestUri) .from(defaultAuthorizationRequestResolver.resolve(request))
.build(); .additionalParameters(
}; Collections.singletonMap("idp", request.getParameter("idp")))
.authorizationRequestUri(customAuthorizationRequestUri)
.build();
when(resolver.resolve(any())).thenReturn(result);
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver); OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);