When expired retrieve new Client Credentials token.

Once client credentials access token has expired retrieve a new token from the OAuth2 authorization server.
These tokens can't be refreshed because they do not have a refresh token associated with. This is standard behaviour for Oauth 2 client credentails

Fixes gh-5893
This commit is contained in:
Warren Bailey 2018-11-16 16:39:29 +00:00 committed by Joe Grandja
parent 9b65107922
commit 1c9ab9197e
5 changed files with 204 additions and 5 deletions

View File

@ -133,7 +133,7 @@ class OAuth2AuthorizedClientResolver {
});
}
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
Mono<OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)

View File

@ -84,8 +84,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository));
}
ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) {
this.authorizedClientRepository = authorizedClientRepository;
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
this.authorizedClientResolver = authorizedClientResolver;
}
/**
@ -245,13 +249,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
}
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
if (shouldRefresh(authorizedClient)) {
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
return createRequest(request)
.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
} else if (shouldRefresh(authorizedClient)) {
return createRequest(request)
.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
}
return Mono.just(authorizedClient);
}
private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
}
private Mono<OAuth2AuthorizedClient> authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) {
Authentication authentication = request.getAuthentication();
ServerWebExchange exchange = request.getExchange();
return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
.thenReturn(result));
}
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
ServerWebExchange exchange = r.getExchange();
@ -280,6 +301,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

View File

@ -332,12 +332,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
if (clientRegistration == null) {
throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
}
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
if (isClientCredentialsGrantType(clientRegistration)) {
return getAuthorizedClient(clientRegistration, attrs);
}
throw new ClientAuthorizationRequiredException(clientRegistrationId);
}
private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
}
private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
Map<String, Object> attrs) {
@ -366,7 +370,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
}
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
if (shouldRefresh(authorizedClient)) {
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
//Client credentials grant do not have refresh tokens but can expire so we need to get another one
return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes()));
} else if (shouldRefresh(authorizedClient)) {
return refreshAuthorizedClient(request, next, authorizedClient);
}
return Mono.just(authorizedClient);
@ -407,6 +415,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

View File

@ -42,6 +42,7 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@ -67,6 +68,7 @@ import java.util.Optional;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
@ -86,6 +88,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock
private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver;
@Mock
private ServerWebExchange serverWebExchange;
@ -144,6 +149,88 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
}
@Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
String clientRegistrationId = registration.getClientId();
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"new-token",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", newAccessToken, null);
Request r = new Request(clientRegistrationId, authentication, null);
when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient));
when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r));
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();
this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any());
verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}
@Test
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();
this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();
verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any());
verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}
@Test
public void filterWhenRefreshRequiredThenRefresh() {
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());

View File

@ -78,6 +78,7 @@ import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
@ -423,6 +424,80 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request1)).isEmpty();
}
@Test
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
this.registration = TestClientRegistrations.clientCredentials().build();
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
this.authorizedClientRepository);
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(authentication(this.authentication))
.build();
this.function.filter(request, this.exchange).block();
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}
@Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
this.registration = TestClientRegistrations.clientCredentials().build();
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse().build();
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
accessTokenResponse);
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
this.authorizedClientRepository);
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(authentication(this.authentication))
.build();
this.function.filter(request, this.exchange).block();
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
verify(clientCredentialsTokenResponseClient).getTokenResponse(any());
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}
@Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")