From 1c9ab9197edfeae3d1ba6676fe1c77ff9b25a1d6 Mon Sep 17 00:00:00 2001 From: Warren Bailey Date: Fri, 16 Nov 2018 16:39:29 +0000 Subject: [PATCH] When expired retrieve new Client Credentials token. Once client credentials access token has expired retrieve a new token from the OAuth2 authorization server. These tokens can't be refreshed because they do not have a refresh token associated with. This is standard behaviour for Oauth 2 client credentails Fixes gh-5893 --- .../OAuth2AuthorizedClientResolver.java | 2 +- ...uthorizedClientExchangeFilterFunction.java | 29 ++++++- ...uthorizedClientExchangeFilterFunction.java | 16 +++- ...izedClientExchangeFilterFunctionTests.java | 87 +++++++++++++++++++ ...izedClientExchangeFilterFunctionTests.java | 75 ++++++++++++++++ 5 files changed, 204 insertions(+), 5 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java index df1a365566..2180562190 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java @@ -133,7 +133,7 @@ class OAuth2AuthorizedClientResolver { }); } - private Mono clientCredentials( + Mono clientCredentials( ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) { OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index ef92036690..548c918014 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -84,8 +84,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private final OAuth2AuthorizedClientResolver authorizedClientResolver; public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository)); + } + + ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) { this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); + this.authorizedClientResolver = authorizedClientResolver; } /** @@ -245,13 +249,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements } private Mono refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { - if (shouldRefresh(authorizedClient)) { + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) { + return createRequest(request) + .flatMap(r -> authorizeWithClientCredentials(clientRegistration, r)); + } else if (shouldRefresh(authorizedClient)) { return createRequest(request) .flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r)); } return Mono.just(authorizedClient); } + private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) { + return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()); + } + + private Mono authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) { + Authentication authentication = request.getAuthentication(); + ServerWebExchange exchange = request.getExchange(); + + return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange). + flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange) + .thenReturn(result)); + } + private Mono refreshAuthorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) { ServerWebExchange exchange = r.getExchange(); @@ -280,6 +301,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements if (refreshToken == null) { return false; } + return hasTokenExpired(authorizedClient); + } + + private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { Instant now = this.clock.instant(); Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 6914363aa5..40b244fecf 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -332,12 +332,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement if (clientRegistration == null) { throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); } - if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + if (isClientCredentialsGrantType(clientRegistration)) { return getAuthorizedClient(clientRegistration, attrs); } throw new ClientAuthorizationRequiredException(clientRegistrationId); } + private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) { + return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()); + } + private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration, Map attrs) { @@ -366,7 +370,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement } private Mono authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { - if (shouldRefresh(authorizedClient)) { + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) { + //Client credentials grant do not have refresh tokens but can expire so we need to get another one + return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes())); + } else if (shouldRefresh(authorizedClient)) { return refreshAuthorizedClient(request, next, authorizedClient); } return Mono.just(authorizedClient); @@ -407,6 +415,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement if (refreshToken == null) { return false; } + return hasTokenExpired(authorizedClient); + } + + private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { Instant now = this.clock.instant(); Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 97d17d0fb1..1f20a75cbb 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -42,6 +42,7 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -67,6 +68,7 @@ import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -86,6 +88,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private ReactiveClientRegistrationRepository clientRegistrationRepository; + @Mock + private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver; + @Mock private ServerWebExchange serverWebExchange; @@ -144,6 +149,88 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); } + @Test + public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); + ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); + String clientRegistrationId = registration.getClientId(); + + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver); + + OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "new-token", + Instant.now(), + Instant.now().plus(Duration.ofDays(1))); + OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration, + "principalName", newAccessToken, null); + Request r = new Request(clientRegistrationId, authentication, null); + when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient)); + when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r)); + + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + + OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, + "principalName", accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); + verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any()); + verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); + ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); + + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, + "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + + verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any()); + verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + @Test public void filterWhenRefreshRequiredThenRefresh() { when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 1431864f85..e6dfa7a1e9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -78,6 +78,7 @@ import static org.assertj.core.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -423,6 +424,80 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(getBody(request1)).isEmpty(); } + @Test + public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { + this.registration = TestClientRegistrations.clientCredentials().build(); + + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); + this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); + + verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { + this.registration = TestClientRegistrations.clientCredentials().build(); + + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse().build(); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( + accessTokenResponse); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); + this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); + + verify(clientCredentialsTokenResponseClient).getTokenResponse(any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")