Add OAuth2AuthorizationRequestResolver.resolve(HttpServletRequest,String)
Previously there was a tangle between DefaultOAuth2AuthorizationRequestResolver and OAuth2AuthorizationRequestRedirectFilter with AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME This commit adds a new method that can be used for resolving the OAuth2AuthorizationRequest when the client registration id is known. Issue: gh-4911
This commit is contained in:
parent
06df562d61
commit
938dbbf424
|
@ -192,21 +192,16 @@ public class OAuth2ClientConfigurerTests {
|
||||||
public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
|
public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
|
||||||
// Override default resolver
|
// Override default resolver
|
||||||
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
|
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
|
||||||
authorizationRequestResolver = request -> {
|
authorizationRequestResolver = mock(OAuth2AuthorizationRequestResolver.class);
|
||||||
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
|
when(authorizationRequestResolver.resolve(any())).thenAnswer(invocation -> defaultAuthorizationRequestResolver.resolve(invocation.getArgument(0)));
|
||||||
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
|
||||||
additionalParameters.put("param1", "value1");
|
|
||||||
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
|
|
||||||
.additionalParameters(additionalParameters)
|
|
||||||
.build();
|
|
||||||
};
|
|
||||||
|
|
||||||
this.spring.register(OAuth2ClientConfig.class).autowire();
|
this.spring.register(OAuth2ClientConfig.class).autowire();
|
||||||
|
|
||||||
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
|
this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
|
||||||
.andExpect(status().is3xxRedirection())
|
.andExpect(status().is3xxRedirection())
|
||||||
.andReturn();
|
.andReturn();
|
||||||
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1¶m1=value1");
|
|
||||||
|
verify(authorizationRequestResolver).resolve(any());
|
||||||
}
|
}
|
||||||
|
|
||||||
@EnableWebSecurity
|
@EnableWebSecurity
|
||||||
|
|
|
@ -44,7 +44,6 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
|
||||||
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
|
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
|
||||||
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
|
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
|
||||||
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
|
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
|
||||||
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
|
|
||||||
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
|
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
|
||||||
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
|
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
|
||||||
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
|
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
|
||||||
|
@ -78,6 +77,9 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for {@link OAuth2LoginConfigurer}.
|
* Tests for {@link OAuth2LoginConfigurer}.
|
||||||
|
@ -236,6 +238,15 @@ public class OAuth2LoginConfigurerTests {
|
||||||
@Test
|
@Test
|
||||||
public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
|
public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
|
||||||
loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
|
loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
|
||||||
|
OAuth2AuthorizationRequestResolver resolver = this.context.getBean(
|
||||||
|
OAuth2LoginConfigCustomAuthorizationRequestResolver.class).resolver;
|
||||||
|
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest.authorizationCode()
|
||||||
|
.authorizationUri("https://accounts.google.com/authorize")
|
||||||
|
.clientId("client-id")
|
||||||
|
.state("adsfa")
|
||||||
|
.authorizationRequestUri("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1")
|
||||||
|
.build();
|
||||||
|
when(resolver.resolve(any())).thenReturn(result);
|
||||||
|
|
||||||
String requestUri = "/oauth2/authorization/google";
|
String requestUri = "/oauth2/authorization/google";
|
||||||
this.request = new MockHttpServletRequest("GET", requestUri);
|
this.request = new MockHttpServletRequest("GET", requestUri);
|
||||||
|
@ -243,7 +254,7 @@ public class OAuth2LoginConfigurerTests {
|
||||||
|
|
||||||
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
|
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
|
||||||
|
|
||||||
assertThat(this.response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/v2/auth\\?response_type=code&client_id=clientId&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
|
assertThat(this.response.getRedirectedUrl()).isEqualTo("https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
|
||||||
}
|
}
|
||||||
|
|
||||||
// gh-5347
|
// gh-5347
|
||||||
|
@ -492,28 +503,17 @@ public class OAuth2LoginConfigurerTests {
|
||||||
private ClientRegistrationRepository clientRegistrationRepository =
|
private ClientRegistrationRepository clientRegistrationRepository =
|
||||||
new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION);
|
new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION);
|
||||||
|
|
||||||
|
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void configure(HttpSecurity http) throws Exception {
|
protected void configure(HttpSecurity http) throws Exception {
|
||||||
http
|
http
|
||||||
.oauth2Login()
|
.oauth2Login()
|
||||||
.clientRegistrationRepository(this.clientRegistrationRepository)
|
.clientRegistrationRepository(this.clientRegistrationRepository)
|
||||||
.authorizationEndpoint()
|
.authorizationEndpoint()
|
||||||
.authorizationRequestResolver(this.getAuthorizationRequestResolver());
|
.authorizationRequestResolver(this.resolver);
|
||||||
super.configure(http);
|
super.configure(http);
|
||||||
}
|
}
|
||||||
|
|
||||||
private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
|
|
||||||
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver =
|
|
||||||
new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, "/oauth2/authorization");
|
|
||||||
return request -> {
|
|
||||||
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
|
|
||||||
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
|
||||||
additionalParameters.put("custom-param1", "custom-value1");
|
|
||||||
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
|
|
||||||
.additionalParameters(additionalParameters)
|
|
||||||
.build();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@EnableWebSecurity
|
@EnableWebSecurity
|
||||||
|
|
|
@ -17,7 +17,6 @@ package org.springframework.security.oauth2.client.web;
|
||||||
|
|
||||||
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
|
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
|
||||||
import org.springframework.security.crypto.keygen.StringKeyGenerator;
|
import org.springframework.security.crypto.keygen.StringKeyGenerator;
|
||||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
|
||||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
||||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
|
@ -33,8 +32,6 @@ import java.util.Base64;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
|
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
|
||||||
* resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}
|
* resolve an {@link OAuth2AuthorizationRequest} from the provided {@code HttpServletRequest}
|
||||||
|
@ -45,6 +42,7 @@ import static org.springframework.security.oauth2.client.web.OAuth2Authorization
|
||||||
* via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
|
* via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
|
||||||
*
|
*
|
||||||
* @author Joe Grandja
|
* @author Joe Grandja
|
||||||
|
* @author Rob Winch
|
||||||
* @since 5.1
|
* @since 5.1
|
||||||
* @see OAuth2AuthorizationRequestResolver
|
* @see OAuth2AuthorizationRequestResolver
|
||||||
* @see OAuth2AuthorizationRequestRedirectFilter
|
* @see OAuth2AuthorizationRequestRedirectFilter
|
||||||
|
@ -73,6 +71,28 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
|
||||||
@Override
|
@Override
|
||||||
public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
|
public OAuth2AuthorizationRequest resolve(HttpServletRequest request) {
|
||||||
String registrationId = this.resolveRegistrationId(request);
|
String registrationId = this.resolveRegistrationId(request);
|
||||||
|
String redirectUriAction = getAction(request, "login");
|
||||||
|
return resolve(request, registrationId, redirectUriAction);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId) {
|
||||||
|
if (registrationId == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String redirectUriAction = getAction(request, "authorize");
|
||||||
|
return resolve(request, registrationId, redirectUriAction);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getAction(HttpServletRequest request, String defaultAction) {
|
||||||
|
String action = request.getParameter("action");
|
||||||
|
if (action == null) {
|
||||||
|
return defaultAction;
|
||||||
|
}
|
||||||
|
return action;
|
||||||
|
}
|
||||||
|
|
||||||
|
private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId, String redirectUriAction) {
|
||||||
if (registrationId == null) {
|
if (registrationId == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -93,7 +113,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
|
||||||
") for Client Registration with Id: " + clientRegistration.getRegistrationId());
|
") for Client Registration with Id: " + clientRegistration.getRegistrationId());
|
||||||
}
|
}
|
||||||
|
|
||||||
String redirectUriAction = this.resolveRedirectUriAction(request, clientRegistration);
|
|
||||||
String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
|
String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
|
||||||
|
|
||||||
Map<String, Object> additionalParameters = new HashMap<>();
|
Map<String, Object> additionalParameters = new HashMap<>();
|
||||||
|
@ -112,13 +131,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
|
||||||
}
|
}
|
||||||
|
|
||||||
private String resolveRegistrationId(HttpServletRequest request) {
|
private String resolveRegistrationId(HttpServletRequest request) {
|
||||||
// Check for ClientAuthorizationRequiredException which may have been set
|
|
||||||
// in the request by OAuth2AuthorizationRequestRedirectFilter
|
|
||||||
ClientAuthorizationRequiredException authzEx =
|
|
||||||
(ClientAuthorizationRequiredException) request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
|
|
||||||
if (authzEx != null) {
|
|
||||||
return authzEx.getClientRegistrationId();
|
|
||||||
}
|
|
||||||
if (this.authorizationRequestMatcher.matches(request)) {
|
if (this.authorizationRequestMatcher.matches(request)) {
|
||||||
return this.authorizationRequestMatcher
|
return this.authorizationRequestMatcher
|
||||||
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
|
.extractUriTemplateVariables(request).get(REGISTRATION_ID_URI_VARIABLE_NAME);
|
||||||
|
@ -126,29 +138,6 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
private String resolveRedirectUriAction(HttpServletRequest request, ClientRegistration clientRegistration) {
|
|
||||||
String action = null;
|
|
||||||
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
|
|
||||||
String loginAction = "login";
|
|
||||||
String authorizeAction = "authorize";
|
|
||||||
String actionParameter = request.getParameter("action");
|
|
||||||
if (request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME) != null) {
|
|
||||||
// Check for ClientAuthorizationRequiredException which may have been set
|
|
||||||
// in the request by OAuth2AuthorizationRequestRedirectFilter
|
|
||||||
action = authorizeAction;
|
|
||||||
} else if (actionParameter == null) {
|
|
||||||
action = loginAction; // Default
|
|
||||||
} else {
|
|
||||||
if (actionParameter.equalsIgnoreCase(loginAction)) {
|
|
||||||
action = loginAction;
|
|
||||||
} else {
|
|
||||||
action = authorizeAction;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return action;
|
|
||||||
}
|
|
||||||
|
|
||||||
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
|
private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
|
||||||
// Supported URI variables -> baseUrl, action, registrationId
|
// Supported URI variables -> baseUrl, action, registrationId
|
||||||
// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
|
// Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
|
||||||
|
|
|
@ -79,8 +79,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||||
* The default base {@code URI} used for authorization requests.
|
* The default base {@code URI} used for authorization requests.
|
||||||
*/
|
*/
|
||||||
public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization";
|
public static final String DEFAULT_AUTHORIZATION_REQUEST_BASE_URI = "/oauth2/authorization";
|
||||||
static final String AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME =
|
|
||||||
ClientAuthorizationRequiredException.class.getName() + ".AUTHORIZATION_REQUIRED_EXCEPTION";
|
|
||||||
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
|
private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
|
||||||
private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
|
private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy();
|
||||||
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
|
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
|
||||||
|
@ -169,8 +167,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||||
.getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain);
|
.getFirstThrowableOfType(ClientAuthorizationRequiredException.class, causeChain);
|
||||||
if (authzEx != null) {
|
if (authzEx != null) {
|
||||||
try {
|
try {
|
||||||
request.setAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME, authzEx);
|
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request, authzEx.getClientRegistrationId());
|
||||||
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestResolver.resolve(request);
|
|
||||||
if (authorizationRequest == null) {
|
if (authorizationRequest == null) {
|
||||||
throw authzEx;
|
throw authzEx;
|
||||||
}
|
}
|
||||||
|
@ -178,8 +175,6 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
||||||
this.requestCache.saveRequest(request, response);
|
this.requestCache.saveRequest(request, response);
|
||||||
} catch (Exception failed) {
|
} catch (Exception failed) {
|
||||||
this.unsuccessfulRedirectForAuthorization(request, response, failed);
|
this.unsuccessfulRedirectForAuthorization(request, response, failed);
|
||||||
} finally {
|
|
||||||
request.removeAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME);
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletRequest;
|
||||||
* Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests.
|
* Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for resolving Authorization Requests.
|
||||||
*
|
*
|
||||||
* @author Joe Grandja
|
* @author Joe Grandja
|
||||||
|
* @author Rob Winch
|
||||||
* @since 5.1
|
* @since 5.1
|
||||||
* @see OAuth2AuthorizationRequest
|
* @see OAuth2AuthorizationRequest
|
||||||
* @see OAuth2AuthorizationRequestRedirectFilter
|
* @see OAuth2AuthorizationRequestRedirectFilter
|
||||||
|
@ -40,4 +41,14 @@ public interface OAuth2AuthorizationRequestResolver {
|
||||||
*/
|
*/
|
||||||
OAuth2AuthorizationRequest resolve(HttpServletRequest request);
|
OAuth2AuthorizationRequest resolve(HttpServletRequest request);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link OAuth2AuthorizationRequest} resolved from
|
||||||
|
* the provided {@code HttpServletRequest} or {@code null} if not available.
|
||||||
|
*
|
||||||
|
* @param request the {@code HttpServletRequest}
|
||||||
|
* @param clientRegistrationId the clientRegistrationId to use
|
||||||
|
* @return the resolved {@link OAuth2AuthorizationRequest} or {@code null} if not available
|
||||||
|
*/
|
||||||
|
OAuth2AuthorizationRequest resolve(HttpServletRequest request, String clientRegistrationId);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,6 @@ package org.springframework.security.oauth2.client.web;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.springframework.mock.web.MockHttpServletRequest;
|
import org.springframework.mock.web.MockHttpServletRequest;
|
||||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
|
||||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||||
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
||||||
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
|
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
|
||||||
|
@ -28,7 +27,9 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.*;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
|
import static org.assertj.core.api.Assertions.entry;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
|
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
|
||||||
|
@ -139,11 +140,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
|
||||||
String requestUri = "/path";
|
String requestUri = "/path";
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
request.setServletPath(requestUri);
|
request.setServletPath(requestUri);
|
||||||
request.setAttribute(
|
|
||||||
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
|
|
||||||
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
|
|
||||||
|
|
||||||
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
|
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
|
||||||
assertThat(authorizationRequest).isNotNull();
|
assertThat(authorizationRequest).isNotNull();
|
||||||
assertThat(authorizationRequest.getAdditionalParameters())
|
assertThat(authorizationRequest.getAdditionalParameters())
|
||||||
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
|
.containsExactly(entry(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()));
|
||||||
|
@ -213,11 +211,8 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
|
||||||
String requestUri = "/path";
|
String requestUri = "/path";
|
||||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||||
request.setServletPath(requestUri);
|
request.setServletPath(requestUri);
|
||||||
request.setAttribute(
|
|
||||||
OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME,
|
|
||||||
new ClientAuthorizationRequiredException(clientRegistration.getRegistrationId()));
|
|
||||||
|
|
||||||
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
|
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request, clientRegistration.getRegistrationId());
|
||||||
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
|
assertThat(authorizationRequest.getAuthorizationRequestUri()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,13 +37,13 @@ import javax.servlet.ServletResponse;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import java.lang.reflect.Constructor;
|
import java.lang.reflect.Constructor;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
import static org.mockito.Mockito.*;
|
import static org.mockito.Mockito.*;
|
||||||
import static org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter.AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
|
* Tests for {@link OAuth2AuthorizationRequestRedirectFilter}.
|
||||||
|
@ -274,7 +274,6 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
||||||
|
|
||||||
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
|
assertThat(response.getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fauthorize%2Foauth2%2Fcode%2Fregistration-1");
|
||||||
verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
verify(this.requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||||
assertThat(request.getAttribute(AUTHORIZATION_REQUIRED_EXCEPTION_ATTR_NAME)).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -288,7 +287,7 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
||||||
doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId()))
|
doThrow(new ClientAuthorizationRequiredException(this.registration1.getRegistrationId()))
|
||||||
.when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
|
.when(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
|
||||||
|
|
||||||
OAuth2AuthorizationRequestResolver resolver = req -> null;
|
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
|
||||||
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
@ -315,14 +314,13 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
||||||
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
|
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
|
||||||
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
|
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
|
||||||
|
|
||||||
OAuth2AuthorizationRequestResolver resolver = req -> {
|
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
|
||||||
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
|
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
|
||||||
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
.from(defaultAuthorizationRequestResolver.resolve(request))
|
||||||
additionalParameters.put("idp", req.getParameter("idp"));
|
.additionalParameters(
|
||||||
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
|
Collections.singletonMap("idp", request.getParameter("idp")))
|
||||||
.additionalParameters(additionalParameters)
|
.build();
|
||||||
.build();
|
when(resolver.resolve(any())).thenReturn(result);
|
||||||
};
|
|
||||||
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
@ -347,19 +345,23 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
|
||||||
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
|
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
|
||||||
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
|
this.clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
|
||||||
|
|
||||||
OAuth2AuthorizationRequestResolver resolver = req -> {
|
OAuth2AuthorizationRequestResolver resolver = mock(OAuth2AuthorizationRequestResolver.class);
|
||||||
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(req);
|
|
||||||
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
|
||||||
additionalParameters.put(loginHintParamName, req.getParameter(loginHintParamName));
|
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
|
||||||
String customAuthorizationRequestUri = UriComponentsBuilder
|
additionalParameters.put(loginHintParamName, request.getParameter(loginHintParamName));
|
||||||
.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
|
String customAuthorizationRequestUri = UriComponentsBuilder
|
||||||
.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
|
.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
|
||||||
.build(true).toUriString();
|
.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
|
||||||
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
|
.build(true).toUriString();
|
||||||
.additionalParameters(additionalParameters)
|
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
|
||||||
.authorizationRequestUri(customAuthorizationRequestUri)
|
.from(defaultAuthorizationRequestResolver.resolve(request))
|
||||||
.build();
|
.additionalParameters(
|
||||||
};
|
Collections.singletonMap("idp", request.getParameter("idp")))
|
||||||
|
.authorizationRequestUri(customAuthorizationRequestUri)
|
||||||
|
.build();
|
||||||
|
when(resolver.resolve(any())).thenReturn(result);
|
||||||
|
|
||||||
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
|
||||||
|
|
||||||
filter.doFilter(request, response, filterChain);
|
filter.doFilter(request, response, filterChain);
|
||||||
|
|
Loading…
Reference in New Issue