Allow configuring a custom OAuth2AuthorizationRequestResolver

Fixes gh-5521
This commit is contained in:
Joe Grandja 2018-07-16 16:12:18 -04:00
parent becff23df1
commit 2cd548221d
4 changed files with 130 additions and 15 deletions

View File

@ -28,6 +28,7 @@ import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAut
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.util.Assert;
@ -147,6 +148,7 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
*/
public class AuthorizationEndpointConfig {
private String authorizationRequestBaseUri;
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
private AuthorizationEndpointConfig() {
@ -164,6 +166,18 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
return this;
}
/**
* Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
*
* @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s
* @return the {@link AuthorizationEndpointConfig} for further configuration
*/
public AuthorizationEndpointConfig authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
this.authorizationRequestResolver = authorizationRequestResolver;
return this;
}
/**
* Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
*
@ -267,13 +281,19 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>> exte
}
private void configure(B builder, AuthorizationCodeGrantConfigurer authorizationCodeGrantConfigurer) throws Exception {
String authorizationRequestBaseUri = authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestBaseUri;
if (authorizationRequestBaseUri == null) {
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter;
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), authorizationRequestBaseUri);
if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestResolver != null) {
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestResolver);
} else {
String authorizationRequestBaseUri = authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestBaseUri;
if (authorizationRequestBaseUri == null) {
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), authorizationRequestBaseUri);
}
if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {
authorizationRequestFilter.setAuthorizationRequestRepository(

View File

@ -44,6 +44,7 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@ -178,6 +179,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
*/
public class AuthorizationEndpointConfig {
private String authorizationRequestBaseUri;
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
private AuthorizationEndpointConfig() {
@ -195,6 +197,19 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
return this;
}
/**
* Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
*
* @since 5.1
* @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s
* @return the {@link AuthorizationEndpointConfig} for further configuration
*/
public AuthorizationEndpointConfig authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
this.authorizationRequestResolver = authorizationRequestResolver;
return this;
}
/**
* Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
*
@ -444,13 +459,19 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
@Override
public void configure(B http) throws Exception {
String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri;
if (authorizationRequestBaseUri == null) {
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter;
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri);
if (this.authorizationEndpointConfig.authorizationRequestResolver != null) {
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
this.authorizationEndpointConfig.authorizationRequestResolver);
} else {
String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri;
if (authorizationRequestBaseUri == null) {
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri);
}
if (this.authorizationEndpointConfig.authorizationRequestRepository != null) {
authorizationRequestFilter.setAuthorizationRequestRepository(

View File

@ -37,7 +37,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@ -74,6 +76,8 @@ public class OAuth2ClientConfigurerTests {
private static OAuth2AuthorizedClientService authorizedClientService;
private static OAuth2AuthorizationRequestResolver authorizationRequestResolver;
private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private static RequestCache requestCache;
@ -103,6 +107,8 @@ public class OAuth2ClientConfigurerTests {
.build();
clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
clientRegistrationRepository, "/oauth2/authorization");
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
@ -173,6 +179,28 @@ public class OAuth2ClientConfigurerTests {
verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
// gh-5521
@Test
public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
// Override default resolver
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
authorizationRequestResolver = request -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("param1", "value1");
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
this.spring.register(OAuth2ClientConfig.class).autowire();
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
.andExpect(status().is3xxRedirection())
.andReturn();
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1&param1=value1");
}
@EnableWebSecurity
@EnableWebMvc
static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter {
@ -188,6 +216,9 @@ public class OAuth2ClientConfigurerTests {
.oauth2()
.client()
.authorizationCodeGrant()
.authorizationEndpoint()
.authorizationRequestResolver(authorizationRequestResolver)
.and()
.tokenEndpoint()
.accessTokenResponseClient(accessTokenResponseClient);
}

View File

@ -42,7 +42,9 @@ import org.springframework.security.oauth2.client.registration.InMemoryClientReg
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@ -105,11 +107,9 @@ public class OAuth2LoginConfigurerTests {
@Before
public void setup() {
this.request = new MockHttpServletRequest("GET", "");
this.request.setServletPath("/login/oauth2/code/google");
this.response = new MockHttpServletResponse();
this.filterChain = new MockFilterChain();
this.request.setMethod("GET");
this.request.setServletPath("/login/oauth2/code/google");
}
@After
@ -225,6 +225,20 @@ public class OAuth2LoginConfigurerTests {
.isInstanceOf(OAuth2UserAuthority.class).hasToString("ROLE_USER");
}
// gh-5521
@Test
public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
String requestUri = "/oauth2/authorization/google";
this.request = new MockHttpServletRequest("GET", requestUri);
this.request.setServletPath(requestUri);
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
assertThat(this.response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/v2/auth\\?response_type=code&client_id=clientId&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
}
@Test
public void oidcLogin() throws Exception {
// setup application context
@ -406,6 +420,35 @@ public class OAuth2LoginConfigurerTests {
}
}
@EnableWebSecurity
static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonWebSecurityConfigurerAdapter {
private ClientRegistrationRepository clientRegistrationRepository =
new InMemoryClientRegistrationRepository(CLIENT_REGISTRATION);
@Override
protected void configure(HttpSecurity http) throws Exception {
http
.oauth2Login()
.clientRegistrationRepository(this.clientRegistrationRepository)
.authorizationEndpoint()
.authorizationRequestResolver(this.getAuthorizationRequestResolver());
super.configure(http);
}
private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver =
new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, "/oauth2/authorization");
return request -> {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put("custom-param1", "custom-value1");
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
.additionalParameters(additionalParameters)
.build();
};
}
}
private static abstract class CommonWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter {
@Override
protected void configure(HttpSecurity http) throws Exception {