JwtBearerOAuth2AuthorizedClientProvider checks for access token expiry

Fixes gh-9700
This commit is contained in:
Joe Grandja 2021-04-30 09:21:06 -04:00
parent fc6fa79c86
commit 761e3a9dd8
2 changed files with 131 additions and 8 deletions

View File

@ -16,6 +16,10 @@
package org.springframework.security.oauth2.client; package org.springframework.security.oauth2.client;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest; import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
@ -23,6 +27,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResp
import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@ -40,12 +45,18 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
private OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient(); private OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient();
private Duration clockSkew = Duration.ofSeconds(60);
private Clock clock = Clock.systemUTC();
/** /**
* Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() * Attempt to authorize (or re-authorize) the
* client} in the provided {@code context}. Returns {@code null} if authorization is * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided
* not supported, e.g. the client's * {@code context}. Returns {@code null} if authorization (or re-authorization) is not
* {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is * supported, e.g. the client's {@link ClientRegistration#getAuthorizationGrantType()
* not {@link AuthorizationGrantType#JWT_BEARER jwt-bearer}. * authorization grant type} is not {@link AuthorizationGrantType#JWT_BEARER
* jwt-bearer} OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is
* not expired.
* @param context the context that holds authorization-specific state for the client * @param context the context that holds authorization-specific state for the client
* @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not
* supported * supported
@ -59,8 +70,9 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
return null; return null;
} }
OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient();
if (authorizedClient != null) { if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) {
// Client is already authorized // If client is already authorized but access token is NOT expired than no
// need for re-authorization
return null; return null;
} }
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) { if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
@ -95,6 +107,10 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
} }
} }
private boolean hasTokenExpired(OAuth2Token token) {
return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew));
}
/** /**
* Sets the client used when requesting an access token credential at the Token * Sets the client used when requesting an access token credential at the Token
* Endpoint for the {@code jwt-bearer} grant. * Endpoint for the {@code jwt-bearer} grant.
@ -107,4 +123,31 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
this.accessTokenResponseClient = accessTokenResponseClient; this.accessTokenResponseClient = accessTokenResponseClient;
} }
/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is
* 60 seconds.
*
* <p>
* An access token is considered expired if
* {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time
* {@code clock#instant()}.
* @param clockSkew the maximum acceptable clock skew
*/
public void setClockSkew(Duration clockSkew) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
this.clockSkew = clockSkew;
}
/**
* Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access
* token expiry.
* @param clock the clock
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "clock cannot be null");
this.clock = clock;
}
} }

View File

@ -16,6 +16,9 @@
package org.springframework.security.oauth2.client; package org.springframework.security.oauth2.client;
import java.time.Duration;
import java.time.Instant;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -27,6 +30,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
@ -83,6 +87,33 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
.withMessage("accessTokenResponseClient cannot be null"); .withMessage("accessTokenResponseClient cannot be null");
} }
@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null");
// @formatter:on
}
@Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0");
// @formatter:on
}
@Test
public void setClockWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null");
// @formatter:on
}
@Test @Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
// @formatter:off // @formatter:off
@ -105,7 +136,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
} }
@Test @Test
public void authorizeWhenJwtBearerAndAuthorizedThenNotAuthorized() { public void authorizeWhenJwtBearerAndTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write")); this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write"));
// @formatter:off // @formatter:off
@ -117,6 +148,55 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@Test
public void authorizeWhenJwtBearerAndTokenExpiredThenReauthorize() {
Instant now = Instant.now();
Instant issuedAt = now.minus(Duration.ofMinutes(60));
Instant expiresAt = now.minus(Duration.ofMinutes(30));
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234",
issuedAt, expiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), accessToken);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
@Test
public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorize() {
Instant now = Instant.now();
Instant issuedAt = now.minus(Duration.ofMinutes(60));
Instant expiresAt = now.plus(Duration.ofMinutes(1));
OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"access-token-1234", issuedAt, expiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), expiresInOneMinAccessToken);
// Shorten the lifespan of the access token by 90 seconds, which will ultimately
// force it to expire on the client
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
}
@Test @Test
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() { public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
// @formatter:off // @formatter:off