diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 65d6c26c4b..9c8e088458 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -26,10 +26,15 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -107,16 +112,35 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = + new DefaultClientCredentialsTokenResponseClient(); + private boolean defaultOAuth2AuthorizedClient; public ServletOAuth2AuthorizedClientExchangeFilterFunction() {} - public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) { + public ServletOAuth2AuthorizedClientExchangeFilterFunction( + ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; } + /** + * Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for + * client_credentials grant. + * @param clientCredentialsTokenResponseClient the client to use + */ + public void setClientCredentialsTokenResponseClient( + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + } + /** * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be @@ -277,18 +301,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId(); } if (clientRegistrationId != null) { - HttpServletRequest request = (HttpServletRequest) attrs.get( - HTTP_SERVLET_REQUEST_ATTR_NAME); + HttpServletRequest request = getRequest(attrs); OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository .loadAuthorizedClient(clientRegistrationId, authentication, request); if (authorizedClient == null) { - throw new ClientAuthorizationRequiredException(clientRegistrationId); + authorizedClient = getAuthorizedClient(clientRegistrationId, attrs); } oauth2AuthorizedClient(authorizedClient).accept(attrs); } } + private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map attrs) { + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + if (clientRegistration == null) { + throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); + } + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return getAuthorizedClient(clientRegistration, attrs); + } + throw new ClientAuthorizationRequiredException(clientRegistrationId); + } + + + private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration, + Map attrs) { + + HttpServletRequest request = getRequest(attrs); + HttpServletResponse response = getResponse(attrs); + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(clientRegistration); + OAuth2AccessTokenResponse tokenResponse = + this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + Authentication principal = getAuthentication(attrs); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, + (principal != null ? principal.getName() : "anonymousUser"), + tokenResponse.getAccessToken()); + + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + principal, + request, + response); + + return authorizedClient; + } + private Mono authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { if (shouldRefresh(authorizedClient)) { return refreshAuthorizedClient(request, next, authorizedClient); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java index d7e4fc6c5a..4764723c55 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java @@ -54,4 +54,11 @@ public class TestClientRegistrations { .clientId("client-id-2") .clientSecret("client-secret"); } + + public static ClientRegistration.Builder clientCredentials() { + return clientRegistration() + .registrationId("client-credentials") + .clientId("client-id") + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 4e448d2cca..c98d549c15 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -46,12 +46,16 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -89,6 +93,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private OAuth2AuthorizedClientRepository authorizedClientRepository; @Mock + private ClientRegistrationRepository clientRegistrationRepository; + @Mock + private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; + @Mock private WebClient.RequestHeadersSpec spec; @Captor private ArgumentCaptor>> attrs; @@ -148,7 +156,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); assertThat(getAuthentication(attrs)).isEqualTo(this.authentication); @@ -157,7 +166,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); oauth2AuthorizedClient(authorizedClient).accept(this.result); @@ -168,7 +178,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); Map attrs = getDefaultRequestAttributes(); assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); verifyZeroInteractions(this.authorizedClientRepository); @@ -176,7 +187,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); Map attrs = getDefaultRequestAttributes(); assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); verifyZeroInteractions(this.authorizedClientRepository); @@ -196,7 +208,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); this.function.setDefaultOAuth2AuthorizedClient(true); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); @@ -214,7 +227,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); @@ -227,7 +241,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); @@ -245,9 +260,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); - OAuth2User user = mock(OAuth2User.class); - List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); @@ -259,6 +273,41 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any()); } + @Test + public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { + this.registration = TestClientRegistrations.clientCredentials().build(); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); + this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse().build(); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( + accessTokenResponse); + + clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); + + Map attrs = getDefaultRequestAttributes(); + OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); + + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); + } + + @Test + public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() { + this.registration = TestClientRegistrations.clientCredentials().build(); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); + + clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); + + assertThatCode(() -> getDefaultRequestAttributes()) + .isInstanceOf(IllegalArgumentException.class); + } + private Map getDefaultRequestAttributes() { this.function.defaultRequest().accept(this.spec); verify(this.spec).attributes(this.attrs.capture()); @@ -322,7 +371,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -368,7 +418,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -400,7 +451,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); @@ -422,7 +474,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,