NIFI-9807 Added Refresh Window Property to OAuth2 Token Provider

- Removed hard-coded expiry margin from AccessToken.isExpired() determination

This closes #5876

Signed-off-by: Mike Thomsen <mthomsen@apache.org>
This commit is contained in:
exceptionfactory 2022-03-17 10:30:49 -05:00 committed by Mike Thomsen
parent fc30b649cc
commit ab0d2c2f72
No known key found for this signature in database
GPG Key ID: 88511C3D4CAD246F
4 changed files with 45 additions and 34 deletions

View File

@ -20,8 +20,6 @@ package org.apache.nifi.oauth2;
import java.time.Instant;
public class AccessToken {
private static final int EXPIRY_MARGIN_SECONDS = 300;
private String accessToken;
private String refreshToken;
private String tokenType;
@ -31,7 +29,7 @@ public class AccessToken {
private final Instant fetchTime;
public AccessToken() {
this.fetchTime = now();
this.fetchTime = Instant.now();
}
public AccessToken(String accessToken, String refreshToken, String tokenType, long expiresIn, String scopes) {
@ -88,11 +86,8 @@ public class AccessToken {
}
public boolean isExpired() {
final Instant expirationTime = fetchTime.plusSeconds(expiresIn).minusSeconds(EXPIRY_MARGIN_SECONDS);
boolean expired = now().isAfter(expirationTime);
return expired;
final Instant expirationTime = fetchTime.plusSeconds(expiresIn);
return now().isAfter(expirationTime);
}
Instant now() {

View File

@ -25,6 +25,10 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class AccessTokenTest {
private static final long FIVE_MINUTES_AGO = -300;
private static final long IN_FIVE_MINUTES = 300;
private Instant now;
@BeforeEach
@ -33,22 +37,15 @@ public class AccessTokenTest {
}
@Test
public void testIsExpiredInLessThan5Minutes() {
final AccessToken accessToken = getAccessToken(299);
public void testIsExpiredFiveMinutesAgo() {
final AccessToken accessToken = getAccessToken(FIVE_MINUTES_AGO);
assertTrue(accessToken.isExpired());
}
@Test
public void testIsExpiredInExactly5Minutes() {
final AccessToken accessToken = getAccessToken(300);
assertFalse(accessToken.isExpired());
}
@Test
public void testIsExpiredInMoreThan5Minutes() {
final AccessToken accessToken = getAccessToken(301);
public void testIsExpiredInFiveMinutes() {
final AccessToken accessToken = getAccessToken(IN_FIVE_MINUTES);
assertFalse(accessToken.isExpired());
}

View File

@ -44,11 +44,13 @@ import javax.net.ssl.SSLContext;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
@Tags({"oauth2", "provider", "authorization", "access token", "http"})
@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
@ -120,9 +122,18 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
.build();
public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder()
.name("refresh-window")
.displayName("Refresh Window")
.description("The service will attempt to refresh tokens expiring within the refresh window, subtracting the configured duration from the token expiration.")
.addValidator(StandardValidators.TIME_PERIOD_VALIDATOR)
.defaultValue("0 s")
.required(true)
.build();
public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
.name("ssl-context-service")
.displayName("SSL Context Servuce")
.displayName("SSL Context Service")
.addValidator(Validator.VALID)
.identifiesControllerService(SSLContextService.class)
.required(false)
@ -135,6 +146,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
PASSWORD,
CLIENT_ID,
CLIENT_SECRET,
REFRESH_WINDOW,
SSL_CONTEXT
));
@ -150,6 +162,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
private volatile String password;
private volatile String clientId;
private volatile String clientSecret;
private volatile long refreshWindowSeconds;
private volatile AccessToken accessDetails;
@ -169,6 +182,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
password = context.getProperty(PASSWORD).getValue();
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
refreshWindowSeconds = context.getProperty(REFRESH_WINDOW).asTimePeriod(TimeUnit.SECONDS);
}
@OnDisabled
@ -215,14 +230,14 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
public AccessToken getAccessDetails() {
if (this.accessDetails == null) {
acquireAccessDetails();
} else if (this.accessDetails.isExpired()) {
} else if (isRefreshRequired()) {
if (this.accessDetails.getRefreshToken() == null) {
acquireAccessDetails();
} else {
try {
refreshAccessDetails();
} catch (Exception e) {
getLogger().info("Couldn't refresh access token", e);
getLogger().info("Refresh Access Token request failed [{}]", authorizationServerUrl, e);
acquireAccessDetails();
}
}
@ -232,7 +247,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
}
private void acquireAccessDetails() {
getLogger().debug("Getting a new access token");
getLogger().debug("New Access Token request started [{}]", authorizationServerUrl);
FormBody.Builder acquireTokenBuilder = new FormBody.Builder();
@ -260,7 +275,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
}
private void refreshAccessDetails() {
getLogger().debug("Refreshing access token");
getLogger().debug("Refresh Access Token request started [{}]", authorizationServerUrl);
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
.add("grant_type", "refresh_token")
@ -297,4 +312,12 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
throw new UncheckedIOException("OAuth2 access token request failed", e);
}
}
private boolean isRefreshRequired() {
final Instant expirationRefreshTime = accessDetails.getFetchTime()
.plusSeconds(accessDetails.getExpiresIn())
.minusSeconds(refreshWindowSeconds);
return Instant.now().isAfter(expirationRefreshTime);
}
}

View File

@ -43,11 +43,13 @@ import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -60,6 +62,7 @@ public class StandardOauth2AccessTokenProviderTest {
private static final String PASSWORD = "password";
private static final String CLIENT_ID = "clientId";
private static final String CLIENT_SECRET = "clientSecret";
private static final long FIVE_MINUTES = 300;
private StandardOauth2AccessTokenProvider testSubject;
@ -72,8 +75,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Mock
private ComponentLog mockLogger;
@Captor
private ArgumentCaptor<String> debugCaptor;
@Captor
private ArgumentCaptor<String> errorCaptor;
@Captor
private ArgumentCaptor<Throwable> throwableCaptor;
@ -98,6 +99,7 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
testSubject.onEnabled(mockContext);
}
@ -378,13 +380,7 @@ public class StandardOauth2AccessTokenProviderTest {
}
private void checkLoggedDebugWhenRefreshFails() {
verify(mockLogger, times(3)).debug(debugCaptor.capture());
List<String> actualDebugMessages = debugCaptor.getAllValues();
assertEquals(
Arrays.asList("Getting a new access token", "Refreshing access token", "Getting a new access token"),
actualDebugMessages
);
verify(mockLogger, times(3)).debug(anyString(), eq(AUTHORIZATION_SERVER_URL));
}
private void checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(List<String> expectedLoggedError) {
@ -395,7 +391,7 @@ public class StandardOauth2AccessTokenProviderTest {
}
private void checkLoggedRefreshError(Throwable expectedRefreshError) {
verify(mockLogger).info(eq("Couldn't refresh access token"), throwableCaptor.capture());
verify(mockLogger).info(anyString(), eq(AUTHORIZATION_SERVER_URL), throwableCaptor.capture());
Throwable actualRefreshError = throwableCaptor.getValue();
checkError(expectedRefreshError, actualRefreshError);