NIFI-9797 Corrected AccessToken.isExpired() margin calculation

Signed-off-by: Nathan Gough <thenatog@gmail.com>

This closes #5867.
This commit is contained in:
exceptionfactory 2022-03-14 16:39:47 -05:00 committed by Nathan Gough
parent 36b3f18424
commit 77c45cabc5
3 changed files with 86 additions and 50 deletions

View File

@ -17,10 +17,11 @@
package org.apache.nifi.oauth2;
import java.time.Duration;
import java.time.Instant;
public class AccessToken {
private static final int EXPIRY_MARGIN_SECONDS = 5;
private String accessToken;
private String refreshToken;
private String tokenType;
@ -29,8 +30,6 @@ public class AccessToken {
private final Instant fetchTime;
public static final int EXPIRY_MARGIN = 5000;
public AccessToken() {
this.fetchTime = Instant.now();
}
@ -89,8 +88,7 @@ public class AccessToken {
}
public boolean isExpired() {
boolean expired = Duration.between(Instant.now(), fetchTime.plusSeconds(expiresIn - EXPIRY_MARGIN)).isNegative();
return expired;
final Instant expirationTime = fetchTime.plusSeconds(expiresIn).plusSeconds(EXPIRY_MARGIN_SECONDS);
return Instant.now().isAfter(expirationTime);
}
}

View File

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.oauth2;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class AccessTokenTest {
private static final String ACCESS_TOKEN = "ACCESS";
private static final String REFRESH_TOKEN = "REFRESH";
private static final String TOKEN_TYPE = "Bearer";
private static final String SCOPES = "default";
private static final long TWO_SECONDS_AGO = -2;
private static final long TEN_SECONDS_AGO = -10;
private static final long IN_SIXTY_SECONDS = 60;
@Test
public void testIsExpiredTenSecondsAgo() {
final AccessToken accessToken = getAccessToken(TEN_SECONDS_AGO);
assertTrue(accessToken.isExpired());
}
@Test
public void testIsExpiredTwoSecondsAgo() {
final AccessToken accessToken = getAccessToken(TWO_SECONDS_AGO);
assertFalse(accessToken.isExpired());
}
@Test
public void testIsExpiredInSixtySeconds() {
final AccessToken accessToken = getAccessToken(IN_SIXTY_SECONDS);
assertFalse(accessToken.isExpired());
}
private AccessToken getAccessToken(final long expiresInSeconds) {
return new AccessToken(
ACCESS_TOKEN,
REFRESH_TOKEN,
TOKEN_TYPE,
expiresInSeconds,
SCOPES
);
}
}

View File

@ -29,28 +29,31 @@ import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.util.NoOpProcessor;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.junit.jupiter.MockitoExtension;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StandardOauth2AccessTokenProviderTest {
private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl";
private static final String USERNAME = "username";
@ -75,10 +78,8 @@ public class StandardOauth2AccessTokenProviderTest {
@Captor
private ArgumentCaptor<Throwable> throwableCaptor;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
@BeforeEach
public void setUp() {
testSubject = new StandardOauth2AccessTokenProvider() {
@Override
protected OkHttpClient createHttpClient(ConfigurationContext context) {
@ -103,7 +104,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test
public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception {
// GIVEN
Processor processor = new NoOpProcessor();
TestRunner runner = TestRunners.newTestRunner(processor);
@ -111,16 +111,13 @@ public class StandardOauth2AccessTokenProviderTest {
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
// WHEN
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
// THEN
runner.assertNotValid(testSubject);
}
@Test
public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception {
// GIVEN
Processor processor = new NoOpProcessor();
TestRunner runner = TestRunners.newTestRunner(processor);
@ -128,12 +125,10 @@ public class StandardOauth2AccessTokenProviderTest {
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
// WHEN
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId");
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret");
// THEN
runner.assertValid(testSubject);
}
@ -141,7 +136,6 @@ public class StandardOauth2AccessTokenProviderTest {
public void testAcquireNewToken() throws Exception {
String accessTokenValue = "access_token_value";
// GIVEN
Response response = buildResponse(
200,
"{ \"access_token\":\"" + accessTokenValue + "\" }"
@ -149,22 +143,19 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
// WHEN
String actual = testSubject.getAccessDetails().getAccessToken();
// THEN
assertEquals(accessTokenValue, actual);
}
@Test
public void testRefreshToken() throws Exception {
// GIVEN
String firstToken = "first_token";
String expectedToken = "second_token";
Response response1 = buildResponse(
200,
"{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
"{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"-60\", \"refresh_token\":\"not_checking_in_this_test\" }"
);
Response response2 = buildResponse(
@ -174,17 +165,14 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
// WHEN
testSubject.getAccessDetails();
String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
assertEquals(expectedToken, actualToken);
}
@Test
public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
// GIVEN
String refreshErrorMessage = "refresh_error";
String acquireErrorMessage = "acquire_error";
@ -203,16 +191,12 @@ public class StandardOauth2AccessTokenProviderTest {
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
});
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
UncheckedIOException actualException = assertThrows(
UncheckedIOException.class,
() -> testSubject.getAccessDetails()
);
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
@ -222,7 +206,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test
public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
// GIVEN
String refreshErrorMessage = "refresh_error";
String expectedToken = "expected_token";
@ -246,13 +229,9 @@ public class StandardOauth2AccessTokenProviderTest {
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
});
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
@ -262,7 +241,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test
public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception {
// GIVEN
String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }";
String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }";
@ -289,16 +267,12 @@ public class StandardOauth2AccessTokenProviderTest {
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 503, errorAcquireResponseBody)
);
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
ProcessException actualException = assertThrows(
ProcessException.class,
() -> testSubject.getAccessDetails()
);
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
@ -310,7 +284,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test
public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
// GIVEN
String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }";
String expectedToken = "expected_token";
@ -335,15 +308,11 @@ public class StandardOauth2AccessTokenProviderTest {
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
});
List<String> expectedLoggedError = Arrays.asList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse));
List<String> expectedLoggedError = Collections.singletonList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse));
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
@ -356,7 +325,7 @@ public class StandardOauth2AccessTokenProviderTest {
private Response buildSuccessfulInitResponse() {
return buildResponse(
200,
"{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
"{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"-60\", \"refresh_token\":\"not_checking_in_this_test\" }"
);
}