NIFI-9797 Corrected AccessToken.isExpired() margin calculation

Signed-off-by: Nathan Gough <thenatog@gmail.com>

This closes #5867.
This commit is contained in:
exceptionfactory 2022-03-14 16:39:47 -05:00 committed by Nathan Gough
parent 36b3f18424
commit 77c45cabc5
3 changed files with 86 additions and 50 deletions

View File

@ -17,10 +17,11 @@
package org.apache.nifi.oauth2; package org.apache.nifi.oauth2;
import java.time.Duration;
import java.time.Instant; import java.time.Instant;
public class AccessToken { public class AccessToken {
private static final int EXPIRY_MARGIN_SECONDS = 5;
private String accessToken; private String accessToken;
private String refreshToken; private String refreshToken;
private String tokenType; private String tokenType;
@ -29,8 +30,6 @@ public class AccessToken {
private final Instant fetchTime; private final Instant fetchTime;
public static final int EXPIRY_MARGIN = 5000;
public AccessToken() { public AccessToken() {
this.fetchTime = Instant.now(); this.fetchTime = Instant.now();
} }
@ -89,8 +88,7 @@ public class AccessToken {
} }
public boolean isExpired() { public boolean isExpired() {
boolean expired = Duration.between(Instant.now(), fetchTime.plusSeconds(expiresIn - EXPIRY_MARGIN)).isNegative(); final Instant expirationTime = fetchTime.plusSeconds(expiresIn).plusSeconds(EXPIRY_MARGIN_SECONDS);
return Instant.now().isAfter(expirationTime);
return expired;
} }
} }

View File

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.oauth2;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class AccessTokenTest {
private static final String ACCESS_TOKEN = "ACCESS";
private static final String REFRESH_TOKEN = "REFRESH";
private static final String TOKEN_TYPE = "Bearer";
private static final String SCOPES = "default";
private static final long TWO_SECONDS_AGO = -2;
private static final long TEN_SECONDS_AGO = -10;
private static final long IN_SIXTY_SECONDS = 60;
@Test
public void testIsExpiredTenSecondsAgo() {
final AccessToken accessToken = getAccessToken(TEN_SECONDS_AGO);
assertTrue(accessToken.isExpired());
}
@Test
public void testIsExpiredTwoSecondsAgo() {
final AccessToken accessToken = getAccessToken(TWO_SECONDS_AGO);
assertFalse(accessToken.isExpired());
}
@Test
public void testIsExpiredInSixtySeconds() {
final AccessToken accessToken = getAccessToken(IN_SIXTY_SECONDS);
assertFalse(accessToken.isExpired());
}
private AccessToken getAccessToken(final long expiresInSeconds) {
return new AccessToken(
ACCESS_TOKEN,
REFRESH_TOKEN,
TOKEN_TYPE,
expiresInSeconds,
SCOPES
);
}
}

View File

@ -29,28 +29,31 @@ import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.util.NoOpProcessor; import org.apache.nifi.util.NoOpProcessor;
import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners; import org.apache.nifi.util.TestRunners;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Answers; import org.mockito.Answers;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Captor; import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.junit.jupiter.MockitoExtension;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StandardOauth2AccessTokenProviderTest { public class StandardOauth2AccessTokenProviderTest {
private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl"; private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl";
private static final String USERNAME = "username"; private static final String USERNAME = "username";
@ -75,10 +78,8 @@ public class StandardOauth2AccessTokenProviderTest {
@Captor @Captor
private ArgumentCaptor<Throwable> throwableCaptor; private ArgumentCaptor<Throwable> throwableCaptor;
@Before @BeforeEach
public void setUp() throws Exception { public void setUp() {
MockitoAnnotations.initMocks(this);
testSubject = new StandardOauth2AccessTokenProvider() { testSubject = new StandardOauth2AccessTokenProvider() {
@Override @Override
protected OkHttpClient createHttpClient(ConfigurationContext context) { protected OkHttpClient createHttpClient(ConfigurationContext context) {
@ -103,7 +104,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test @Test
public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception { public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception {
// GIVEN
Processor processor = new NoOpProcessor(); Processor processor = new NoOpProcessor();
TestRunner runner = TestRunners.newTestRunner(processor); TestRunner runner = TestRunners.newTestRunner(processor);
@ -111,16 +111,13 @@ public class StandardOauth2AccessTokenProviderTest {
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant"); runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
// WHEN
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
// THEN
runner.assertNotValid(testSubject); runner.assertNotValid(testSubject);
} }
@Test @Test
public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception { public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception {
// GIVEN
Processor processor = new NoOpProcessor(); Processor processor = new NoOpProcessor();
TestRunner runner = TestRunners.newTestRunner(processor); TestRunner runner = TestRunners.newTestRunner(processor);
@ -128,12 +125,10 @@ public class StandardOauth2AccessTokenProviderTest {
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant"); runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
// WHEN
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE); runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId"); runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId");
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret"); runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret");
// THEN
runner.assertValid(testSubject); runner.assertValid(testSubject);
} }
@ -141,7 +136,6 @@ public class StandardOauth2AccessTokenProviderTest {
public void testAcquireNewToken() throws Exception { public void testAcquireNewToken() throws Exception {
String accessTokenValue = "access_token_value"; String accessTokenValue = "access_token_value";
// GIVEN
Response response = buildResponse( Response response = buildResponse(
200, 200,
"{ \"access_token\":\"" + accessTokenValue + "\" }" "{ \"access_token\":\"" + accessTokenValue + "\" }"
@ -149,22 +143,19 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response); when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
// WHEN
String actual = testSubject.getAccessDetails().getAccessToken(); String actual = testSubject.getAccessDetails().getAccessToken();
// THEN
assertEquals(accessTokenValue, actual); assertEquals(accessTokenValue, actual);
} }
@Test @Test
public void testRefreshToken() throws Exception { public void testRefreshToken() throws Exception {
// GIVEN
String firstToken = "first_token"; String firstToken = "first_token";
String expectedToken = "second_token"; String expectedToken = "second_token";
Response response1 = buildResponse( Response response1 = buildResponse(
200, 200,
"{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }" "{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"-60\", \"refresh_token\":\"not_checking_in_this_test\" }"
); );
Response response2 = buildResponse( Response response2 = buildResponse(
@ -174,17 +165,14 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2); when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
// WHEN
testSubject.getAccessDetails(); testSubject.getAccessDetails();
String actualToken = testSubject.getAccessDetails().getAccessToken(); String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
assertEquals(expectedToken, actualToken); assertEquals(expectedToken, actualToken);
} }
@Test @Test
public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception { public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
// GIVEN
String refreshErrorMessage = "refresh_error"; String refreshErrorMessage = "refresh_error";
String acquireErrorMessage = "acquire_error"; String acquireErrorMessage = "acquire_error";
@ -203,16 +191,12 @@ public class StandardOauth2AccessTokenProviderTest {
throw new IllegalStateException("Test improperly defined mock HTTP responses."); throw new IllegalStateException("Test improperly defined mock HTTP responses.");
}); });
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails(); testSubject.getAccessDetails();
// WHEN
UncheckedIOException actualException = assertThrows( UncheckedIOException actualException = assertThrows(
UncheckedIOException.class, UncheckedIOException.class,
() -> testSubject.getAccessDetails() () -> testSubject.getAccessDetails()
); );
// THEN
checkLoggedDebugWhenRefreshFails(); checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage))); checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
@ -222,7 +206,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test @Test
public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception { public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
// GIVEN
String refreshErrorMessage = "refresh_error"; String refreshErrorMessage = "refresh_error";
String expectedToken = "expected_token"; String expectedToken = "expected_token";
@ -246,13 +229,9 @@ public class StandardOauth2AccessTokenProviderTest {
throw new IllegalStateException("Test improperly defined mock HTTP responses."); throw new IllegalStateException("Test improperly defined mock HTTP responses.");
}); });
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails(); testSubject.getAccessDetails();
// WHEN
String actualToken = testSubject.getAccessDetails().getAccessToken(); String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
checkLoggedDebugWhenRefreshFails(); checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage))); checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
@ -262,7 +241,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test @Test
public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception { public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception {
// GIVEN
String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }"; String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }";
String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }"; String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }";
@ -289,16 +267,12 @@ public class StandardOauth2AccessTokenProviderTest {
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 503, errorAcquireResponseBody) String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 503, errorAcquireResponseBody)
); );
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails(); testSubject.getAccessDetails();
// WHEN
ProcessException actualException = assertThrows( ProcessException actualException = assertThrows(
ProcessException.class, ProcessException.class,
() -> testSubject.getAccessDetails() () -> testSubject.getAccessDetails()
); );
// THEN
checkLoggedDebugWhenRefreshFails(); checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]")); checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
@ -310,7 +284,6 @@ public class StandardOauth2AccessTokenProviderTest {
@Test @Test
public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception { public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
// GIVEN
String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }"; String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }";
String expectedToken = "expected_token"; String expectedToken = "expected_token";
@ -335,15 +308,11 @@ public class StandardOauth2AccessTokenProviderTest {
throw new IllegalStateException("Test improperly defined mock HTTP responses."); throw new IllegalStateException("Test improperly defined mock HTTP responses.");
}); });
List<String> expectedLoggedError = Arrays.asList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse)); List<String> expectedLoggedError = Collections.singletonList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse));
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails(); testSubject.getAccessDetails();
// WHEN
String actualToken = testSubject.getAccessDetails().getAccessToken(); String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
checkLoggedDebugWhenRefreshFails(); checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]")); checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
@ -356,7 +325,7 @@ public class StandardOauth2AccessTokenProviderTest {
private Response buildSuccessfulInitResponse() { private Response buildSuccessfulInitResponse() {
return buildResponse( return buildResponse(
200, 200,
"{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }" "{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"-60\", \"refresh_token\":\"not_checking_in_this_test\" }"
); );
} }