mirror of
https://github.com/apache/nifi.git
synced 2025-02-07 02:28:31 +00:00
NIFI-9807 Added Refresh Window Property to OAuth2 Token Provider
- Removed hard-coded expiry margin from AccessToken.isExpired() determination This closes #5876 Signed-off-by: Mike Thomsen <mthomsen@apache.org>
This commit is contained in:
parent
fc30b649cc
commit
ab0d2c2f72
@ -20,8 +20,6 @@ package org.apache.nifi.oauth2;
|
||||
import java.time.Instant;
|
||||
|
||||
public class AccessToken {
|
||||
private static final int EXPIRY_MARGIN_SECONDS = 300;
|
||||
|
||||
private String accessToken;
|
||||
private String refreshToken;
|
||||
private String tokenType;
|
||||
@ -31,7 +29,7 @@ public class AccessToken {
|
||||
private final Instant fetchTime;
|
||||
|
||||
public AccessToken() {
|
||||
this.fetchTime = now();
|
||||
this.fetchTime = Instant.now();
|
||||
}
|
||||
|
||||
public AccessToken(String accessToken, String refreshToken, String tokenType, long expiresIn, String scopes) {
|
||||
@ -88,11 +86,8 @@ public class AccessToken {
|
||||
}
|
||||
|
||||
public boolean isExpired() {
|
||||
final Instant expirationTime = fetchTime.plusSeconds(expiresIn).minusSeconds(EXPIRY_MARGIN_SECONDS);
|
||||
|
||||
boolean expired = now().isAfter(expirationTime);
|
||||
|
||||
return expired;
|
||||
final Instant expirationTime = fetchTime.plusSeconds(expiresIn);
|
||||
return now().isAfter(expirationTime);
|
||||
}
|
||||
|
||||
Instant now() {
|
||||
|
@ -25,6 +25,10 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class AccessTokenTest {
|
||||
private static final long FIVE_MINUTES_AGO = -300;
|
||||
|
||||
private static final long IN_FIVE_MINUTES = 300;
|
||||
|
||||
private Instant now;
|
||||
|
||||
@BeforeEach
|
||||
@ -33,22 +37,15 @@ public class AccessTokenTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsExpiredInLessThan5Minutes() {
|
||||
final AccessToken accessToken = getAccessToken(299);
|
||||
public void testIsExpiredFiveMinutesAgo() {
|
||||
final AccessToken accessToken = getAccessToken(FIVE_MINUTES_AGO);
|
||||
|
||||
assertTrue(accessToken.isExpired());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsExpiredInExactly5Minutes() {
|
||||
final AccessToken accessToken = getAccessToken(300);
|
||||
|
||||
assertFalse(accessToken.isExpired());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIsExpiredInMoreThan5Minutes() {
|
||||
final AccessToken accessToken = getAccessToken(301);
|
||||
public void testIsExpiredInFiveMinutes() {
|
||||
final AccessToken accessToken = getAccessToken(IN_FIVE_MINUTES);
|
||||
|
||||
assertFalse(accessToken.isExpired());
|
||||
}
|
||||
|
@ -44,11 +44,13 @@ import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Tags({"oauth2", "provider", "authorization", "access token", "http"})
|
||||
@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
|
||||
@ -120,9 +122,18 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder()
|
||||
.name("refresh-window")
|
||||
.displayName("Refresh Window")
|
||||
.description("The service will attempt to refresh tokens expiring within the refresh window, subtracting the configured duration from the token expiration.")
|
||||
.addValidator(StandardValidators.TIME_PERIOD_VALIDATOR)
|
||||
.defaultValue("0 s")
|
||||
.required(true)
|
||||
.build();
|
||||
|
||||
public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
|
||||
.name("ssl-context-service")
|
||||
.displayName("SSL Context Servuce")
|
||||
.displayName("SSL Context Service")
|
||||
.addValidator(Validator.VALID)
|
||||
.identifiesControllerService(SSLContextService.class)
|
||||
.required(false)
|
||||
@ -135,6 +146,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
PASSWORD,
|
||||
CLIENT_ID,
|
||||
CLIENT_SECRET,
|
||||
REFRESH_WINDOW,
|
||||
SSL_CONTEXT
|
||||
));
|
||||
|
||||
@ -150,6 +162,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
private volatile String password;
|
||||
private volatile String clientId;
|
||||
private volatile String clientSecret;
|
||||
private volatile long refreshWindowSeconds;
|
||||
|
||||
private volatile AccessToken accessDetails;
|
||||
|
||||
@ -169,6 +182,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
password = context.getProperty(PASSWORD).getValue();
|
||||
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
|
||||
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
|
||||
|
||||
refreshWindowSeconds = context.getProperty(REFRESH_WINDOW).asTimePeriod(TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@OnDisabled
|
||||
@ -215,14 +230,14 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
public AccessToken getAccessDetails() {
|
||||
if (this.accessDetails == null) {
|
||||
acquireAccessDetails();
|
||||
} else if (this.accessDetails.isExpired()) {
|
||||
} else if (isRefreshRequired()) {
|
||||
if (this.accessDetails.getRefreshToken() == null) {
|
||||
acquireAccessDetails();
|
||||
} else {
|
||||
try {
|
||||
refreshAccessDetails();
|
||||
} catch (Exception e) {
|
||||
getLogger().info("Couldn't refresh access token", e);
|
||||
getLogger().info("Refresh Access Token request failed [{}]", authorizationServerUrl, e);
|
||||
acquireAccessDetails();
|
||||
}
|
||||
}
|
||||
@ -232,7 +247,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
}
|
||||
|
||||
private void acquireAccessDetails() {
|
||||
getLogger().debug("Getting a new access token");
|
||||
getLogger().debug("New Access Token request started [{}]", authorizationServerUrl);
|
||||
|
||||
FormBody.Builder acquireTokenBuilder = new FormBody.Builder();
|
||||
|
||||
@ -260,7 +275,7 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
}
|
||||
|
||||
private void refreshAccessDetails() {
|
||||
getLogger().debug("Refreshing access token");
|
||||
getLogger().debug("Refresh Access Token request started [{}]", authorizationServerUrl);
|
||||
|
||||
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
|
||||
.add("grant_type", "refresh_token")
|
||||
@ -297,4 +312,12 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||
throw new UncheckedIOException("OAuth2 access token request failed", e);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isRefreshRequired() {
|
||||
final Instant expirationRefreshTime = accessDetails.getFetchTime()
|
||||
.plusSeconds(accessDetails.getExpiresIn())
|
||||
.minusSeconds(refreshWindowSeconds);
|
||||
|
||||
return Instant.now().isAfter(expirationRefreshTime);
|
||||
}
|
||||
}
|
||||
|
@ -43,11 +43,13 @@ import java.io.UncheckedIOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
@ -60,6 +62,7 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||
private static final String PASSWORD = "password";
|
||||
private static final String CLIENT_ID = "clientId";
|
||||
private static final String CLIENT_SECRET = "clientSecret";
|
||||
private static final long FIVE_MINUTES = 300;
|
||||
|
||||
private StandardOauth2AccessTokenProvider testSubject;
|
||||
|
||||
@ -72,8 +75,6 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||
@Mock
|
||||
private ComponentLog mockLogger;
|
||||
@Captor
|
||||
private ArgumentCaptor<String> debugCaptor;
|
||||
@Captor
|
||||
private ArgumentCaptor<String> errorCaptor;
|
||||
@Captor
|
||||
private ArgumentCaptor<Throwable> throwableCaptor;
|
||||
@ -98,6 +99,7 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
|
||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
|
||||
|
||||
testSubject.onEnabled(mockContext);
|
||||
}
|
||||
@ -378,13 +380,7 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||
}
|
||||
|
||||
private void checkLoggedDebugWhenRefreshFails() {
|
||||
verify(mockLogger, times(3)).debug(debugCaptor.capture());
|
||||
List<String> actualDebugMessages = debugCaptor.getAllValues();
|
||||
|
||||
assertEquals(
|
||||
Arrays.asList("Getting a new access token", "Refreshing access token", "Getting a new access token"),
|
||||
actualDebugMessages
|
||||
);
|
||||
verify(mockLogger, times(3)).debug(anyString(), eq(AUTHORIZATION_SERVER_URL));
|
||||
}
|
||||
|
||||
private void checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(List<String> expectedLoggedError) {
|
||||
@ -395,7 +391,7 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||
}
|
||||
|
||||
private void checkLoggedRefreshError(Throwable expectedRefreshError) {
|
||||
verify(mockLogger).info(eq("Couldn't refresh access token"), throwableCaptor.capture());
|
||||
verify(mockLogger).info(anyString(), eq(AUTHORIZATION_SERVER_URL), throwableCaptor.capture());
|
||||
Throwable actualRefreshError = throwableCaptor.getValue();
|
||||
|
||||
checkError(expectedRefreshError, actualRefreshError);
|
||||
|
Loading…
x
Reference in New Issue
Block a user