From 2e66b9f6cc6f3c421d3d77f125f12db6bc301dce Mon Sep 17 00:00:00 2001 From: Igor Bolic Date: Fri, 17 Jun 2022 09:42:50 +0200 Subject: [PATCH] Allow customization of redirect strategy The default redirect strategy will provide authorization redirect URI within HTTP 302 response Location header. Allowing the configuration of custom redirect strategy will provide an option for the clients to obtain the authorization URI from e.g. HTTP response body as JSON payload, without a need to handle automatic redirection initiated by the HTTP Location header. Closes gh-11373 --- .../oauth2/client/OAuth2ClientConfigurer.java | 17 ++++ .../oauth2/client/OAuth2LoginConfigurer.java | 18 ++++ .../OAuth2ClientBeanDefinitionParser.java | 14 +++ .../http/OAuth2LoginBeanDefinitionParser.java | 12 +++ .../config/web/server/ServerHttpSecurity.java | 46 ++++++++++ .../client/AuthorizationCodeGrantDsl.kt | 4 + .../oauth2/login/AuthorizationEndpointDsl.kt | 4 + .../web/server/ServerOAuth2ClientDsl.kt | 4 + .../config/web/server/ServerOAuth2LoginDsl.kt | 4 + .../security/config/spring-security-5.8.rnc | 6 ++ .../security/config/spring-security-5.8.xsd | 12 +++ .../security/config/spring-security-6.0.rnc | 6 ++ .../security/config/spring-security-6.0.xsd | 12 +++ .../client/OAuth2ClientConfigurerTests.java | 20 +++++ .../client/OAuth2LoginConfigurerTests.java | 82 +++++++++++++++++ ...OAuth2ClientBeanDefinitionParserTests.java | 15 ++++ .../OAuth2LoginBeanDefinitionParserTests.java | 16 ++++ .../web/server/ServerHttpSecurityTests.java | 89 +++++++++++++++++++ .../client/AuthorizationCodeGrantDslTests.kt | 36 ++++++++ .../login/AuthorizationEndpointDslTests.kt | 33 +++++++ .../web/server/ServerOAuth2ClientDslTests.kt | 37 ++++++++ .../web/server/ServerOAuth2LoginDslTests.kt | 34 +++++++ ...ts-CustomAuthorizationRedirectStrategy.xml | 48 ++++++++++ ...ithCustomAuthorizationRedirectStrategy.xml | 38 ++++++++ .../servlet/appendix/namespace/http.adoc | 10 +++ ...th2AuthorizationRequestRedirectFilter.java | 11 ++- ...AuthorizationRequestRedirectWebFilter.java | 11 ++- ...thorizationRequestRedirectFilterTests.java | 36 ++++++++ ...rizationRequestRedirectWebFilterTests.java | 55 ++++++++++++ 29 files changed, 728 insertions(+), 2 deletions(-) create mode 100644 config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomAuthorizationRedirectStrategy.xml create mode 100644 config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-SingleClientRegistration-WithCustomAuthorizationRedirectStrategy.xml diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index a8447e7d14..c81d7b07f8 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequest import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.util.Assert; @@ -171,6 +172,8 @@ public final class OAuth2ClientConfigurer> private AuthorizationRequestRepository authorizationRequestRepository; + private RedirectStrategy authorizationRedirectStrategy; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; private AuthorizationCodeGrantConfigurer() { @@ -202,6 +205,17 @@ public final class OAuth2ClientConfigurer> return this; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link AuthorizationCodeGrantConfigurer} for further configuration + */ + public AuthorizationCodeGrantConfigurer authorizationRedirectStrategy( + RedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + /** * Sets the client used for requesting the access token credential from the Token * Endpoint. @@ -247,6 +261,9 @@ public final class OAuth2ClientConfigurer> authorizationRequestRedirectFilter .setAuthorizationRequestRepository(this.authorizationRequestRepository); } + if (this.authorizationRedirectStrategy != null) { + authorizationRequestRedirectFilter.setAuthorizationRedirectStrategy(this.authorizationRedirectStrategy); + } RequestCache requestCache = builder.getSharedObject(RequestCache.class); if (requestCache != null) { authorizationRequestRedirectFilter.setRequestCache(requestCache); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index dcdf53f121..e26e12c14d 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -64,6 +64,7 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.web.AuthenticationEntryPoint; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; @@ -363,6 +364,10 @@ public final class OAuth2LoginConfigurer> authorizationRequestFilter .setAuthorizationRequestRepository(this.authorizationEndpointConfig.authorizationRequestRepository); } + if (this.authorizationEndpointConfig.authorizationRedirectStrategy != null) { + authorizationRequestFilter + .setAuthorizationRedirectStrategy(this.authorizationEndpointConfig.authorizationRedirectStrategy); + } RequestCache requestCache = http.getSharedObject(RequestCache.class); if (requestCache != null) { authorizationRequestFilter.setRequestCache(requestCache); @@ -526,6 +531,8 @@ public final class OAuth2LoginConfigurer> private AuthorizationRequestRepository authorizationRequestRepository; + private RedirectStrategy authorizationRedirectStrategy; + private AuthorizationEndpointConfig() { } @@ -568,6 +575,17 @@ public final class OAuth2LoginConfigurer> return this; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link AuthorizationEndpointConfig} for further configuration + */ + public AuthorizationEndpointConfig authorizationRedirectStrategy( + RedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + /** * Returns the {@link OAuth2LoginConfigurer} for further configuration. * @return the {@link OAuth2LoginConfigurer} diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java index f2c1ebd0f0..5f039548fe 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java @@ -44,6 +44,8 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; + private static final String ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF = "authorization-redirect-strategy-ref"; + private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; private final BeanReference requestCache; @@ -83,6 +85,7 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { } BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository( authorizationCodeGrantElt); + BeanMetadataElement authorizationRedirectStrategy = getAuthorizationRedirectStrategy(authorizationCodeGrantElt); BeanDefinitionBuilder authorizationRequestRedirectFilterBuilder = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationRequestRedirectFilter.class); String authorizationRequestResolverRef = (authorizationCodeGrantElt != null) @@ -95,6 +98,7 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { } this.authorizationRequestRedirectFilter = authorizationRequestRedirectFilterBuilder .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository) + .addPropertyValue("authorizationRedirectStrategy", authorizationRedirectStrategy) .addPropertyValue("requestCache", this.requestCache).getBeanDefinition(); BeanDefinitionBuilder authorizationCodeGrantFilterBldr = BeanDefinitionBuilder .rootBeanDefinition(OAuth2AuthorizationCodeGrantFilter.class) @@ -126,6 +130,16 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser { .getBeanDefinition(); } + private BeanMetadataElement getAuthorizationRedirectStrategy(Element element) { + String authorizationRedirectStrategyRef = (element != null) + ? element.getAttribute(ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF) : null; + if (StringUtils.hasText(authorizationRedirectStrategyRef)) { + return new RuntimeBeanReference(authorizationRedirectStrategyRef); + } + return BeanDefinitionBuilder.rootBeanDefinition("org.springframework.security.web.DefaultRedirectStrategy") + .getBeanDefinition(); + } + private BeanMetadataElement getAccessTokenResponseClient(Element element) { String accessTokenResponseClientRef = (element != null) ? element.getAttribute(ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF) : null; diff --git a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java index 288b09072e..1b8efc6695 100644 --- a/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java +++ b/config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java @@ -87,6 +87,8 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref"; + private static final String ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF = "authorization-redirect-strategy-ref"; + private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref"; private static final String ATT_USER_AUTHORITIES_MAPPER_REF = "user-authorities-mapper-ref"; @@ -199,6 +201,7 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { } oauth2AuthorizationRequestRedirectFilterBuilder .addPropertyValue("authorizationRequestRepository", authorizationRequestRepository) + .addPropertyValue("authorizationRedirectStrategy", getAuthorizationRedirectStrategy(element)) .addPropertyValue("requestCache", this.requestCache); this.oauth2AuthorizationRequestRedirectFilter = oauth2AuthorizationRequestRedirectFilterBuilder .getBeanDefinition(); @@ -261,6 +264,15 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser { .getBeanDefinition(); } + private BeanMetadataElement getAuthorizationRedirectStrategy(Element element) { + String authorizationRedirectStrategyRef = element.getAttribute(ATT_AUTHORIZATION_REDIRECT_STRATEGY_REF); + if (StringUtils.hasText(authorizationRedirectStrategyRef)) { + return new RuntimeBeanReference(authorizationRedirectStrategyRef); + } + return BeanDefinitionBuilder.rootBeanDefinition("org.springframework.security.web.DefaultRedirectStrategy") + .getBeanDefinition(); + } + private BeanDefinition getOidcAuthProvider(Element element, BeanMetadataElement accessTokenResponseClient, String userAuthoritiesMapperRef) { boolean oidcAuthenticationProviderEnabled = ClassUtils diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 880c22630b..d87bd03d86 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -102,12 +102,14 @@ import org.springframework.security.oauth2.server.resource.web.server.ServerBear import org.springframework.security.web.PortMapper; import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; +import org.springframework.security.web.server.DefaultServerRedirectStrategy; import org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint; import org.springframework.security.web.server.DelegatingServerAuthenticationEntryPoint.DelegateEntry; import org.springframework.security.web.server.ExchangeMatcherRedirectWebFilter; import org.springframework.security.web.server.MatcherSecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; +import org.springframework.security.web.server.ServerRedirectStrategy; import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilter; import org.springframework.security.web.server.authentication.AuthenticationConverterServerWebExchangeMatcher; import org.springframework.security.web.server.authentication.AuthenticationWebFilter; @@ -3375,6 +3377,8 @@ public class ServerHttpSecurity { private ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver; + private ServerRedirectStrategy authorizationRedirectStrategy; + private ServerWebExchangeMatcher authenticationMatcher; private ServerAuthenticationSuccessHandler authenticationSuccessHandler; @@ -3547,6 +3551,16 @@ public class ServerHttpSecurity { return this; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link OAuth2LoginSpec} for further configuration + */ + public OAuth2LoginSpec authorizationRedirectStrategy(ServerRedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + /** * Sets the {@link ServerWebExchangeMatcher matcher} used for determining if the * request is an authentication request. @@ -3581,7 +3595,9 @@ public class ServerHttpSecurity { OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter(); ServerAuthorizationRequestRepository authorizationRequestRepository = getAuthorizationRequestRepository(); oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository); + oauthRedirectFilter.setAuthorizationRedirectStrategy(getAuthorizationRedirectStrategy()); oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); + ReactiveAuthenticationManager manager = getAuthenticationManager(); AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository); @@ -3591,6 +3607,7 @@ public class ServerHttpSecurity { authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http)); authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler()); authenticationFilter.setSecurityContextRepository(this.securityContextRepository); + setDefaultEntryPoints(http); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION); @@ -3737,6 +3754,13 @@ public class ServerHttpSecurity { return this.authorizationRequestRepository; } + private ServerRedirectStrategy getAuthorizationRedirectStrategy() { + if (this.authorizationRedirectStrategy == null) { + this.authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + } + return this.authorizationRedirectStrategy; + } + private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() { ReactiveOAuth2AuthorizedClientService bean = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class); if (bean != null) { @@ -3759,6 +3783,8 @@ public class ServerHttpSecurity { private ServerAuthorizationRequestRepository authorizationRequestRepository; + private ServerRedirectStrategy authorizationRedirectStrategy; + private OAuth2ClientSpec() { } @@ -3851,6 +3877,23 @@ public class ServerHttpSecurity { return this.authorizationRequestRepository; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + * @return the {@link OAuth2ClientSpec} for further configuration + */ + public OAuth2ClientSpec authorizationRedirectStrategy(ServerRedirectStrategy authorizationRedirectStrategy) { + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + return this; + } + + private ServerRedirectStrategy getAuthorizationRedirectStrategy() { + if (this.authorizationRedirectStrategy == null) { + this.authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + } + return this.authorizationRedirectStrategy; + } + /** * Allows method chaining to continue configuring the {@link ServerHttpSecurity} * @return the {@link ServerHttpSecurity} to continue configuring @@ -3870,12 +3913,15 @@ public class ServerHttpSecurity { if (http.requestCache != null) { codeGrantWebFilter.setRequestCache(http.requestCache.requestCache); } + OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter( clientRegistrationRepository); oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository()); + oauthRedirectFilter.setAuthorizationRedirectStrategy(getAuthorizationRedirectStrategy()); if (http.requestCache != null) { oauthRedirectFilter.setRequestCache(http.requestCache.requestCache); } + http.addFilterAt(codeGrantWebFilter, SecurityWebFiltersOrder.OAUTH2_AUTHORIZATION_CODE); http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC); } diff --git a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt index 35356b09e1..735d75c0b3 100644 --- a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDsl.kt @@ -23,6 +23,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCo import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.RedirectStrategy /** * A Kotlin DSL to configure OAuth 2.0 Authorization Code Grant. @@ -31,6 +32,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ * @since 5.3 * @property authorizationRequestResolver the resolver used for resolving [OAuth2AuthorizationRequest]'s. * @property authorizationRequestRepository the repository used for storing [OAuth2AuthorizationRequest]'s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. * @property accessTokenResponseClient the client used for requesting the access token credential * from the Token Endpoint. */ @@ -38,12 +40,14 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ class AuthorizationCodeGrantDsl { var authorizationRequestResolver: OAuth2AuthorizationRequestResolver? = null var authorizationRequestRepository: AuthorizationRequestRepository? = null + var authorizationRedirectStrategy: RedirectStrategy? = null var accessTokenResponseClient: OAuth2AccessTokenResponseClient? = null internal fun get(): (OAuth2ClientConfigurer.AuthorizationCodeGrantConfigurer) -> Unit { return { authorizationCodeGrant -> authorizationRequestResolver?.also { authorizationCodeGrant.authorizationRequestResolver(authorizationRequestResolver) } authorizationRequestRepository?.also { authorizationCodeGrant.authorizationRequestRepository(authorizationRequestRepository) } + authorizationRedirectStrategy?.also { authorizationCodeGrant.authorizationRedirectStrategy(authorizationRedirectStrategy) } accessTokenResponseClient?.also { authorizationCodeGrant.accessTokenResponseClient(accessTokenResponseClient) } } } diff --git a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt index 96289fa825..160efb9081 100644 --- a/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDsl.kt @@ -21,6 +21,7 @@ import org.springframework.security.config.annotation.web.configurers.oauth2.cli import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.RedirectStrategy /** * A Kotlin DSL to configure the Authorization Server's Authorization Endpoint using @@ -31,18 +32,21 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ * @property baseUri the base URI used for authorization requests. * @property authorizationRequestResolver the resolver used for resolving [OAuth2AuthorizationRequest]'s. * @property authorizationRequestRepository the repository used for storing [OAuth2AuthorizationRequest]'s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. */ @OAuth2LoginSecurityMarker class AuthorizationEndpointDsl { var baseUri: String? = null var authorizationRequestResolver: OAuth2AuthorizationRequestResolver? = null var authorizationRequestRepository: AuthorizationRequestRepository? = null + var authorizationRedirectStrategy: RedirectStrategy? = null internal fun get(): (OAuth2LoginConfigurer.AuthorizationEndpointConfig) -> Unit { return { authorizationEndpoint -> baseUri?.also { authorizationEndpoint.baseUri(baseUri) } authorizationRequestResolver?.also { authorizationEndpoint.authorizationRequestResolver(authorizationRequestResolver) } authorizationRequestRepository?.also { authorizationEndpoint.authorizationRequestRepository(authorizationRequestRepository) } + authorizationRedirectStrategy?.also { authorizationEndpoint.authorizationRedirectStrategy(authorizationRedirectStrategy) } } } } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt index 6751d24296..edd50e3345 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDsl.kt @@ -22,6 +22,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.web.server.ServerWebExchange @@ -37,6 +38,7 @@ import org.springframework.web.server.ServerWebExchange * @property clientRegistrationRepository the repository of client registrations. * @property authorizedClientRepository the repository for authorized client(s). * @property authorizationRequestRepository the repository to use for storing [OAuth2AuthorizationRequest]s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. */ @ServerSecurityMarker class ServerOAuth2ClientDsl { @@ -45,6 +47,7 @@ class ServerOAuth2ClientDsl { var clientRegistrationRepository: ReactiveClientRegistrationRepository? = null var authorizedClientRepository: ServerOAuth2AuthorizedClientRepository? = null var authorizationRequestRepository: ServerAuthorizationRequestRepository? = null + var authorizationRedirectStrategy: ServerRedirectStrategy? = null internal fun get(): (ServerHttpSecurity.OAuth2ClientSpec) -> Unit { return { oauth2Client -> @@ -53,6 +56,7 @@ class ServerOAuth2ClientDsl { clientRegistrationRepository?.also { oauth2Client.clientRegistrationRepository(clientRegistrationRepository) } authorizedClientRepository?.also { oauth2Client.authorizedClientRepository(authorizedClientRepository) } authorizationRequestRepository?.also { oauth2Client.authorizationRequestRepository(authorizationRequestRepository) } + authorizationRedirectStrategy?.also { oauth2Client.authorizationRedirectStrategy(authorizationRedirectStrategy) } } } } diff --git a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt index 0c24340fbb..4ab8fcb0e4 100644 --- a/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDsl.kt @@ -24,6 +24,7 @@ import org.springframework.security.oauth2.client.web.server.ServerAuthorization import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler @@ -49,6 +50,7 @@ import org.springframework.web.server.ServerWebExchange * @property authorizedClientRepository the repository for authorized client(s). * @property authorizationRequestRepository the repository to use for storing [OAuth2AuthorizationRequest]s. * @property authorizationRequestResolver the resolver used for resolving [OAuth2AuthorizationRequest]s. + * @property authorizationRedirectStrategy the redirect strategy for Authorization Endpoint redirect URI. * @property authenticationMatcher the [ServerWebExchangeMatcher] used for determining if the request is an * authentication request. */ @@ -64,6 +66,7 @@ class ServerOAuth2LoginDsl { var authorizedClientRepository: ServerOAuth2AuthorizedClientRepository? = null var authorizationRequestRepository: ServerAuthorizationRequestRepository? = null var authorizationRequestResolver: ServerOAuth2AuthorizationRequestResolver? = null + var authorizationRedirectStrategy: ServerRedirectStrategy? = null var authenticationMatcher: ServerWebExchangeMatcher? = null internal fun get(): (ServerHttpSecurity.OAuth2LoginSpec) -> Unit { @@ -78,6 +81,7 @@ class ServerOAuth2LoginDsl { authorizedClientRepository?.also { oauth2Login.authorizedClientRepository(authorizedClientRepository) } authorizationRequestRepository?.also { oauth2Login.authorizationRequestRepository(authorizationRequestRepository) } authorizationRequestResolver?.also { oauth2Login.authorizationRequestResolver(authorizationRequestResolver) } + authorizationRedirectStrategy?.also { oauth2Login.authorizationRedirectStrategy(authorizationRedirectStrategy) } authenticationMatcher?.also { oauth2Login.authenticationMatcher(authenticationMatcher) } } } diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc index a6b80f163e..36afa1b42b 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc @@ -500,6 +500,9 @@ oauth2-login.attlist &= oauth2-login.attlist &= ## Reference to the OAuth2AuthorizationRequestResolver attribute authorization-request-resolver-ref {xsd:token}? +oauth2-login.attlist &= + ## Reference to the authorization RedirectStrategy + attribute authorization-redirect-strategy-ref {xsd:token}? oauth2-login.attlist &= ## Reference to the OAuth2AccessTokenResponseClient attribute access-token-response-client-ref {xsd:token}? @@ -547,6 +550,9 @@ authorization-code-grant = authorization-code-grant.attlist &= ## Reference to the AuthorizationRequestRepository attribute authorization-request-repository-ref {xsd:token}? +authorization-code-grant.attlist &= + ## Reference to the authorization RedirectStrategy + attribute authorization-redirect-strategy-ref {xsd:token}? authorization-code-grant.attlist &= ## Reference to the OAuth2AuthorizationRequestResolver attribute authorization-request-resolver-ref {xsd:token}? diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd index 4d5e16a304..2b98c8e68d 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd @@ -1651,6 +1651,12 @@ + + + Reference to the authorization RedirectStrategy + + + Reference to the OAuth2AccessTokenResponseClient @@ -1754,6 +1760,12 @@ + + + Reference to the authorization RedirectStrategy + + + Reference to the OAuth2AuthorizationRequestResolver diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc index ca2d5b356c..89e1b4c9cc 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc +++ b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc @@ -500,6 +500,9 @@ oauth2-login.attlist &= oauth2-login.attlist &= ## Reference to the OAuth2AuthorizationRequestResolver attribute authorization-request-resolver-ref {xsd:token}? +oauth2-login.attlist &= + ## Reference to the authorization RedirectStrategy + attribute authorization-redirect-strategy-ref {xsd:token}? oauth2-login.attlist &= ## Reference to the OAuth2AccessTokenResponseClient attribute access-token-response-client-ref {xsd:token}? @@ -547,6 +550,9 @@ authorization-code-grant = authorization-code-grant.attlist &= ## Reference to the AuthorizationRequestRepository attribute authorization-request-repository-ref {xsd:token}? +authorization-code-grant.attlist &= + ## Reference to the authorization RedirectStrategy + attribute authorization-redirect-strategy-ref {xsd:token}? authorization-code-grant.attlist &= ## Reference to the OAuth2AuthorizationRequestResolver attribute authorization-request-resolver-ref {xsd:token}? diff --git a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd index 5b9193ee3f..f64f686774 100644 --- a/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd +++ b/config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd @@ -1629,6 +1629,12 @@ + + + Reference to the authorization RedirectStrategy + + + Reference to the OAuth2AccessTokenResponseClient @@ -1732,6 +1738,12 @@ + + + Reference to the authorization RedirectStrategy + + + Reference to the OAuth2AuthorizationRequestResolver diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index f2e6a85f9a..faeaceed2c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -58,6 +58,8 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -68,6 +70,7 @@ import org.springframework.web.servlet.config.annotation.EnableWebMvc; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -95,6 +98,8 @@ public class OAuth2ClientConfigurerTests { private static OAuth2AuthorizationRequestResolver authorizationRequestResolver; + private static RedirectStrategy authorizationRedirectStrategy; + private static OAuth2AccessTokenResponseClient accessTokenResponseClient; private static RequestCache requestCache; @@ -130,6 +135,7 @@ public class OAuth2ClientConfigurerTests { authorizedClientService); authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, "/oauth2/authorization"); + authorizationRedirectStrategy = new DefaultRedirectStrategy(); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(300).build(); accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); @@ -261,6 +267,19 @@ public class OAuth2ClientConfigurerTests { verify(authorizationRequestResolver).resolve(any()); } + @Test + public void configureWhenCustomAuthorizationRedirectStrategySetThenAuthorizationRedirectStrategyUsed() + throws Exception { + authorizationRedirectStrategy = mock(RedirectStrategy.class); + this.spring.register(OAuth2ClientConfig.class).autowire(); + // @formatter:off + this.mockMvc.perform(get("/oauth2/authorization/registration-1")) + .andExpect(status().isOk()) + .andReturn(); + // @formatter:on + verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); + } + @EnableWebSecurity @EnableWebMvc static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter { @@ -278,6 +297,7 @@ public class OAuth2ClientConfigurerTests { .oauth2Client() .authorizationCodeGrant() .authorizationRequestResolver(authorizationRequestResolver) + .authorizationRedirectStrategy(authorizationRedirectStrategy) .accessTokenResponseClient(accessTokenResponseClient); // @formatter:on } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index d591b9b107..b0f5a73ae9 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -87,6 +87,7 @@ import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.web.FilterChainProxy; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; @@ -98,7 +99,9 @@ import org.springframework.web.context.support.AnnotationConfigWebApplicationCon import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -357,6 +360,32 @@ public class OAuth2LoginConfigurerTests { "https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=clientId&scope=openid+profile+email&state=state&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1"); } + @Test + public void oauth2LoginWithAuthorizationRedirectStrategyThenCustomAuthorizationRedirectStrategyUsed() + throws Exception { + loadConfig(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class); + RedirectStrategy redirectStrategy = this.context + .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategy.class).redirectStrategy; + String requestUri = "/oauth2/authorization/google"; + this.request = new MockHttpServletRequest("GET", requestUri); + this.request.setServletPath(requestUri); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); + } + + @Test + public void requestWhenOauth2LoginWithCustomAuthorizationRedirectStrategyThenCustomAuthorizationRedirectStrategyUsed() + throws Exception { + loadConfig(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class); + RedirectStrategy redirectStrategy = this.context + .getBean(OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda.class).redirectStrategy; + String requestUri = "/oauth2/authorization/google"; + this.request = new MockHttpServletRequest("GET", requestUri); + this.request.setServletPath(requestUri); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + then(redirectStrategy).should().sendRedirect(any(), any(), anyString()); + } + // gh-5347 @Test public void oauth2LoginWithOneClientConfiguredThenRedirectForAuthorization() throws Exception { @@ -858,6 +887,59 @@ public class OAuth2LoginConfigurerTests { } + @EnableWebSecurity + static class OAuth2LoginConfigCustomAuthorizationRedirectStrategy extends CommonWebSecurityConfigurerAdapter { + + private final ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + GOOGLE_CLIENT_REGISTRATION); + + RedirectStrategy redirectStrategy = mock(RedirectStrategy.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((oauth2Login) -> + oauth2Login + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizationEndpoint((authorizationEndpoint) -> + authorizationEndpoint + .authorizationRedirectStrategy(this.redirectStrategy) + ) + ); + // @formatter:on + super.configure(http); + } + + } + + @EnableWebSecurity + static class OAuth2LoginConfigCustomAuthorizationRedirectStrategyInLambda + extends CommonLambdaWebSecurityConfigurerAdapter { + + private final ClientRegistrationRepository clientRegistrationRepository = new InMemoryClientRegistrationRepository( + GOOGLE_CLIENT_REGISTRATION); + + RedirectStrategy redirectStrategy = mock(RedirectStrategy.class); + + @Override + protected void configure(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((oauth2Login) -> + oauth2Login + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizationEndpoint((authorizationEndpoint) -> + authorizationEndpoint + .authorizationRedirectStrategy(this.redirectStrategy) + ) + ); + // @formatter:on + super.configure(http); + } + + } + @EnableWebSecurity static class OAuth2LoginConfigMultipleClients extends CommonWebSecurityConfigurerAdapter { diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java index 0e9806118a..f3dbab941c 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java @@ -44,6 +44,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.security.web.RedirectStrategy; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -55,6 +56,7 @@ import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -90,6 +92,9 @@ public class OAuth2ClientBeanDefinitionParserTests { @Autowired(required = false) private OAuth2AuthorizationRequestResolver authorizationRequestResolver; + @Autowired(required = false) + private RedirectStrategy authorizationRedirectStrategy; + @Autowired(required = false) private OAuth2AccessTokenResponseClient accessTokenResponseClient; @@ -148,6 +153,16 @@ public class OAuth2ClientBeanDefinitionParserTests { verify(this.authorizationRequestResolver).resolve(any()); } + @Test + public void requestWhenCustomAuthorizationRedirectStrategyThenCalled() throws Exception { + this.spring.configLocations(xml("CustomAuthorizationRedirectStrategy")).autowire(); + // @formatter:off + this.mvc.perform(get("/oauth2/authorization/google")) + .andExpect(status().isOk()); + // @formatter:on + verify(this.authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); + } + @Test public void requestWhenAuthorizationResponseMatchThenProcess() throws Exception { this.spring.configLocations(xml("CustomConfiguration")).autowire(); diff --git a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java index 38f43a0911..8b98a9a9a0 100644 --- a/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java @@ -63,6 +63,7 @@ import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners; import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.savedrequest.RequestCache; @@ -77,6 +78,7 @@ import org.springframework.web.bind.annotation.RestController; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -116,6 +118,9 @@ public class OAuth2LoginBeanDefinitionParserTests { @Autowired(required = false) private OAuth2AuthorizationRequestResolver authorizationRequestResolver; + @Autowired(required = false) + private RedirectStrategy authorizationRedirectStrategy; + @Autowired(required = false) private OAuth2AccessTokenResponseClient accessTokenResponseClient; @@ -373,6 +378,17 @@ public class OAuth2LoginBeanDefinitionParserTests { verify(this.authorizationRequestResolver).resolve(any()); } + @Test + public void requestWhenCustomAuthorizationRedirectStrategyThenCalled() throws Exception { + this.spring.configLocations(this.xml("SingleClientRegistration-WithCustomAuthorizationRedirectStrategy")) + .autowire(); + // @formatter:off + this.mvc.perform(get("/oauth2/authorization/google-login")) + .andExpect(status().isOk()); + // @formatter:on + verify(this.authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); + } + // gh-5347 @Test public void requestWhenMultiClientRegistrationThenRedirectDefaultLoginPage() throws Exception { diff --git a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java index 3b8c6d97f8..be097e4b07 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java @@ -39,14 +39,18 @@ import org.springframework.security.config.annotation.web.reactive.ServerHttpSec import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter; import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; +import org.springframework.security.web.server.DefaultServerRedirectStrategy; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.ServerAuthenticationEntryPoint; +import org.springframework.security.web.server.ServerRedirectStrategy; import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests; import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint; @@ -76,6 +80,7 @@ import org.springframework.web.server.WebFilterChain; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -531,6 +536,90 @@ public class ServerHttpSecurityTests { verify(authorizationRequestRepository).removeAuthorizationRequest(any()); } + @Test + public void shouldUseDefaultAuthorizationRedirectStrategyForOAuth2Login() { + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Login() + .clientRegistrationRepository(clientRegistrationRepository).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange().expectStatus().is3xxRedirection(); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isInstanceOf(DefaultServerRedirectStrategy.class); + } + + @Test + public void shouldConfigureAuthorizationRedirectStrategyForOAuth2Login() { + ServerRedirectStrategy authorizationRedirectStrategy = mock(ServerRedirectStrategy.class); + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + given(authorizationRedirectStrategy.sendRedirect(any(), any())).willReturn(Mono.empty()); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Login() + .clientRegistrationRepository(clientRegistrationRepository) + .authorizationRedirectStrategy(authorizationRedirectStrategy).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange(); + verify(authorizationRedirectStrategy).sendRedirect(any(), any()); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isSameAs(authorizationRedirectStrategy); + } + + @Test + public void shouldUseDefaultAuthorizationRedirectStrategyForOAuth2Client() { + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Client() + .clientRegistrationRepository(clientRegistrationRepository).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange().expectStatus().is3xxRedirection(); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isInstanceOf(DefaultServerRedirectStrategy.class); + } + + @Test + public void shouldConfigureAuthorizationRedirectStrategyForOAuth2Client() { + ServerRedirectStrategy authorizationRedirectStrategy = mock(ServerRedirectStrategy.class); + ReactiveClientRegistrationRepository clientRegistrationRepository = mock( + ReactiveClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(anyString())) + .willReturn(Mono.just(TestClientRegistrations.clientRegistration().build())); + given(authorizationRedirectStrategy.sendRedirect(any(), any())).willReturn(Mono.empty()); + + SecurityWebFilterChain securityFilterChain = this.http.oauth2Client() + .clientRegistrationRepository(clientRegistrationRepository) + .authorizationRedirectStrategy(authorizationRedirectStrategy).and().build(); + + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); + client.get().uri("/oauth2/authorization/registration-id").exchange(); + verify(authorizationRedirectStrategy).sendRedirect(any(), any()); + + OAuth2AuthorizationRequestRedirectWebFilter filter = getWebFilter(securityFilterChain, + OAuth2AuthorizationRequestRedirectWebFilter.class).get(); + assertThat(ReflectionTestUtils.getField(filter, "authorizationRedirectStrategy")) + .isSameAs(authorizationRedirectStrategy); + } + private boolean isX509Filter(WebFilter filter) { try { Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter"); diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt index f4cb3ed6e5..8d2cec76c1 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/client/AuthorizationCodeGrantDslTests.kt @@ -43,6 +43,8 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames +import org.springframework.security.web.DefaultRedirectStrategy +import org.springframework.security.web.RedirectStrategy import org.springframework.security.web.SecurityFilterChain import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.get @@ -104,6 +106,40 @@ class AuthorizationCodeGrantDslTests { } } + @Test + fun `oauth2Client when custom authorization redirect strategy then redirect strategy used`() { + this.spring.register(RedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(RedirectStrategyConfig.REDIRECT_STRATEGY) + every { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + + this.mockMvc.get("/oauth2/authorization/registrationId") + + verify(exactly = 1) { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + } + + @EnableWebSecurity + open class RedirectStrategyConfig { + + companion object { + val REDIRECT_STRATEGY: RedirectStrategy = DefaultRedirectStrategy() + } + + @Bean + open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain { + http { + oauth2Client { + authorizationCodeGrant { + authorizationRedirectStrategy = REDIRECT_STRATEGY + } + } + authorizeRequests { + authorize(anyRequest, authenticated) + } + } + return http.build() + } + } + @Test fun `oauth2Client when custom access token response client then client used`() { this.spring.register(AuthorizedClientConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt index 5571688f5a..1801f5d954 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/oauth2/login/AuthorizationEndpointDslTests.kt @@ -37,6 +37,8 @@ import org.springframework.security.oauth2.client.web.AuthorizationRequestReposi import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest +import org.springframework.security.web.DefaultRedirectStrategy +import org.springframework.security.web.RedirectStrategy import org.springframework.security.web.SecurityFilterChain import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.get @@ -125,6 +127,37 @@ class AuthorizationEndpointDslTests { } } + @Test + fun `oauth2Login when custom authorization redirect strategy then redirect strategy used`() { + this.spring.register(RedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(RedirectStrategyConfig.REDIRECT_STRATEGY) + every { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + + this.mockMvc.get("/oauth2/authorization/google") + + verify(exactly = 1) { RedirectStrategyConfig.REDIRECT_STRATEGY.sendRedirect(any(), any(), any()) } + } + + @EnableWebSecurity + open class RedirectStrategyConfig { + + companion object { + val REDIRECT_STRATEGY: RedirectStrategy = DefaultRedirectStrategy() + } + + @Bean + open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain { + http { + oauth2Login { + authorizationEndpoint { + authorizationRedirectStrategy = REDIRECT_STRATEGY + } + } + } + return http.build() + } + } + @Test fun `oauth2Login when custom authorization uri repository then uri used`() { this.spring.register(AuthorizationUriConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt index e93816ebdb..a1dc851c1c 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ClientDslTests.kt @@ -39,7 +39,9 @@ import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2Ser import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames import org.springframework.security.oauth2.server.resource.web.server.ServerBearerTokenAuthenticationConverter +import org.springframework.security.web.server.DefaultServerRedirectStrategy import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.web.reactive.config.EnableWebFlux @@ -130,6 +132,41 @@ class ServerOAuth2ClientDslTests { } } + @Test + fun `OAuth2 client when authorization redirect strategy configured then custom redirect strategy used`() { + this.spring.register(AuthorizationRedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY) + every { + AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) + } returns Mono.empty() + + this.client.get() + .uri("/oauth2/authorization/google") + .exchange() + + verify(exactly = 1) { + AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) + } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class AuthorizationRedirectStrategyConfig { + + companion object { + val AUTHORIZATION_REDIRECT_STRATEGY : ServerRedirectStrategy = DefaultServerRedirectStrategy() + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + oauth2Client { + authorizationRedirectStrategy = AUTHORIZATION_REDIRECT_STRATEGY + } + } + } + } + @Test fun `OAuth2 client when authentication converter configured then custom converter used`() { this.spring.register(AuthenticationConverterConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt index 8c73b263e2..5fd23b57aa 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2LoginDslTests.kt @@ -35,7 +35,9 @@ import org.springframework.security.oauth2.client.web.server.ServerAuthorization import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest import org.springframework.security.oauth2.server.resource.web.server.ServerBearerTokenAuthenticationConverter +import org.springframework.security.web.server.DefaultServerRedirectStrategy import org.springframework.security.web.server.SecurityWebFilterChain +import org.springframework.security.web.server.ServerRedirectStrategy import org.springframework.security.web.server.authentication.ServerAuthenticationConverter import org.springframework.security.web.server.util.matcher.IpAddressServerWebExchangeMatcher import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher @@ -141,6 +143,38 @@ class ServerOAuth2LoginDslTests { } } + @Test + fun `OAuth2 login when authorization redirect strategy configured then custom redirect strategy used`() { + this.spring.register(AuthorizationRedirectStrategyConfig::class.java, ClientConfig::class.java).autowire() + mockkObject(AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY) + every { + AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) + } returns Mono.empty() + this.client.get() + .uri("/oauth2/authorization/google") + .exchange() + + verify(exactly = 1) { AuthorizationRedirectStrategyConfig.AUTHORIZATION_REDIRECT_STRATEGY.sendRedirect(any(), any()) } + } + + @EnableWebFluxSecurity + @EnableWebFlux + open class AuthorizationRedirectStrategyConfig { + + companion object { + val AUTHORIZATION_REDIRECT_STRATEGY : ServerRedirectStrategy = DefaultServerRedirectStrategy() + } + + @Bean + open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain { + return http { + oauth2Login { + authorizationRedirectStrategy = AUTHORIZATION_REDIRECT_STRATEGY + } + } + } + } + @Test fun `OAuth2 login when authentication matcher configured then custom matcher used`() { this.spring.register(AuthenticationMatcherConfig::class.java, ClientConfig::class.java).autowire() diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomAuthorizationRedirectStrategy.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomAuthorizationRedirectStrategy.xml new file mode 100644 index 0000000000..d7ff413909 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-CustomAuthorizationRedirectStrategy.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-SingleClientRegistration-WithCustomAuthorizationRedirectStrategy.xml b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-SingleClientRegistration-WithCustomAuthorizationRedirectStrategy.xml new file mode 100644 index 0000000000..8454de28f1 --- /dev/null +++ b/config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-SingleClientRegistration-WithCustomAuthorizationRedirectStrategy.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + diff --git a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc index 0a50f43257..fd3788576a 100644 --- a/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc +++ b/docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc @@ -983,6 +983,11 @@ Reference to the `AuthorizationRequestRepository`. Reference to the `OAuth2AuthorizationRequestResolver`. +[[nsa-oauth2-login-authorization-redirect-strategy-ref]] +* **authorization-redirect-strategy-ref** +Reference to the authorization `RedirectStrategy`. + + [[nsa-oauth2-login-access-token-response-client-ref]] * **access-token-response-client-ref** Reference to the `OAuth2AccessTokenResponseClient`. @@ -1083,6 +1088,11 @@ Configures xref:servlet/oauth2/client/authorization-grants.adoc#oauth2Client-aut Reference to the `AuthorizationRequestRepository`. +[[nsa-authorization-code-grant-authorization-redirect-strategy-ref]] +* **authorization-redirect-strategy-ref** +Reference to the authorization `RedirectStrategy`. + + [[nsa-authorization-code-grant-authorization-request-resolver-ref]] * **authorization-request-resolver-ref** Reference to the `OAuth2AuthorizationRequestResolver`. diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java index 28a6350b94..c5779e888d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java @@ -89,7 +89,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt private final ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer(); - private final RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); + private RedirectStrategy authorizationRedirectStrategy = new DefaultRedirectStrategy(); private OAuth2AuthorizationRequestResolver authorizationRequestResolver; @@ -133,6 +133,15 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt this.authorizationRequestResolver = authorizationRequestResolver; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + */ + public void setAuthorizationRedirectStrategy(RedirectStrategy authorizationRedirectStrategy) { + Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be null"); + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + } + /** * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. * @param authorizationRequestRepository the repository used for storing diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java index deab87c078..81a2e4118e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java @@ -69,7 +69,7 @@ import org.springframework.web.util.UriComponentsBuilder; */ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { - private final ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); + private ServerRedirectStrategy authorizationRedirectStrategy = new DefaultServerRedirectStrategy(); private final ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver; @@ -99,6 +99,15 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter { this.authorizationRequestResolver = authorizationRequestResolver; } + /** + * Sets the redirect strategy for Authorization Endpoint redirect URI. + * @param authorizationRedirectStrategy the redirect strategy + */ + public void setAuthorizationRedirectStrategy(ServerRedirectStrategy authorizationRedirectStrategy) { + Assert.notNull(authorizationRedirectStrategy, "authorizationRedirectStrategy cannot be null"); + this.authorizationRedirectStrategy = authorizationRedirectStrategy; + } + /** * Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s. * @param authorizationRequestRepository the repository used for storing diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java index c4928dde74..e2d74aa781 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web; import java.lang.reflect.Constructor; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -29,7 +30,9 @@ import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; @@ -38,6 +41,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.util.ClassUtils; import org.springframework.web.util.UriComponentsBuilder; @@ -104,6 +108,11 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null)); } + @Test + public void setAuthorizationRedirectStrategyWhenAuthorizationRedirectStrategyIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRedirectStrategy(null)); + } + @Test public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null)); @@ -289,4 +298,31 @@ public class OAuth2AuthorizationRequestRedirectFilterTests { + "login_hint=user@provider\\.com"); } + @Test + public void doFilterWhenCustomAuthorizationRedirectStrategySetThenCustomAuthorizationRedirectStrategyUsed() + throws Exception { + String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/" + + this.registration1.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + RedirectStrategy customRedirectStrategy = (httpRequest, httpResponse, url) -> { + String redirectUrl = httpResponse.encodeRedirectURL(url); + httpResponse.setStatus(HttpStatus.OK.value()); + httpResponse.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE); + httpResponse.getWriter().write(redirectUrl); + httpResponse.getWriter().flush(); + }; + this.filter.setAuthorizationRedirectStrategy(customRedirectStrategy); + this.filter.doFilter(request, response, filterChain); + verifyZeroInteractions(filterChain); + assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); + assertThat(response.getContentType()).isEqualTo(MediaType.TEXT_PLAIN_VALUE); + assertThat(response.getContentAsString(StandardCharsets.UTF_8)) + .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" + + "scope=read:user&state=.{15,}&" + + "redirect_uri=http://localhost/login/oauth2/code/registration-id"); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java index d515c73e06..1821cc140f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.client.web.server; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.junit.jupiter.api.BeforeEach; @@ -24,13 +25,20 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.web.server.ServerRedirectStrategy; import org.springframework.security.web.server.savedrequest.ServerRequestCache; import org.springframework.test.web.reactive.server.FluxExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; @@ -81,6 +89,11 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { .isThrownBy(() -> new OAuth2AuthorizationRequestRedirectWebFilter(this.clientRepository)); } + @Test + public void setterWhenAuthorizationRedirectStrategyNullThenIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRedirectStrategy(null)); + } + @Test public void filterWhenDoesNotMatchThenClientRegistrationRepositoryNotSubscribed() { // @formatter:off @@ -195,4 +208,46 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests { verifyNoInteractions(this.requestCache); } + @Test + public void filterWhenCustomRedirectStrategySetThenRedirectUriInResponseBody() { + given(this.clientRepository.findByRegistrationId(this.registration.getRegistrationId())) + .willReturn(Mono.just(this.registration)); + given(this.authzRequestRepository.saveAuthorizationRequest(any(), any())).willReturn(Mono.empty()); + ServerRedirectStrategy customRedirectStrategy = (exchange, location) -> { + ServerHttpResponse response = exchange.getResponse(); + response.setStatusCode(HttpStatus.OK); + response.getHeaders().setContentType(MediaType.TEXT_PLAIN); + DataBuffer buffer = exchange.getResponse().bufferFactory() + .wrap(location.toASCIIString().getBytes(StandardCharsets.UTF_8)); + + return exchange.getResponse().writeWith(Flux.just(buffer)); + }; + this.filter.setAuthorizationRedirectStrategy(customRedirectStrategy); + this.filter.setRequestCache(this.requestCache); + + FluxExchangeResult result = this.client.get() + .uri("https://example.com/oauth2/authorization/registration-id").exchange().expectHeader() + .contentType(MediaType.TEXT_PLAIN).expectStatus().isOk().returnResult(String.class); + + // @formatter:off + StepVerifier.create(result.getResponseBody()) + .assertNext((uri) -> { + URI location = URI.create(uri); + + assertThat(location) + .hasScheme("https") + .hasHost("example.com") + .hasPath("/login/oauth/authorize") + .hasParameter("response_type", "code") + .hasParameter("client_id", "client-id") + .hasParameter("scope", "read:user") + .hasParameter("state") + .hasParameter("redirect_uri", "https://example.com/login/oauth2/code/registration-id"); + }) + .verifyComplete(); + // @formatter:on + + verifyNoInteractions(this.requestCache); + } + }