Hadoop 16890. Change in expiry calculation for MSI token provider.

Contributed by Bilahari T H
This commit is contained in:
bilaharith 2020-03-12 02:09:10 +05:30 committed by GitHub
parent cf9cf83a43
commit 0b931f36ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 177 additions and 15 deletions

View File

@ -283,6 +283,12 @@
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>hamcrest-library</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -72,7 +72,7 @@ public abstract class AccessTokenProvider {
*
* @return true if the token is expiring in next 5 minutes
*/
private boolean isTokenAboutToExpire() {
protected boolean isTokenAboutToExpire() {
if (token == null) {
LOG.debug("AADToken: no token. Returning expiring=true");
return true; // no token should have same response as expired token

View File

@ -137,7 +137,7 @@ public final class AzureADAuthenticator {
headers.put("Metadata", "true");
LOG.debug("AADToken: starting to fetch token using MSI");
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET");
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true);
}
/**
@ -258,8 +258,13 @@ public final class AzureADAuthenticator {
}
private static AzureADToken getTokenCall(String authEndpoint, String body,
Hashtable<String, String> headers, String httpMethod)
throws IOException {
Hashtable<String, String> headers, String httpMethod) throws IOException {
return getTokenCall(authEndpoint, body, headers, httpMethod, false);
}
private static AzureADToken getTokenCall(String authEndpoint, String body,
Hashtable<String, String> headers, String httpMethod, boolean isMsi)
throws IOException {
AzureADToken token = null;
ExponentialRetryPolicy retryPolicy
= new ExponentialRetryPolicy(3, 0, 1000, 2);
@ -272,7 +277,7 @@ public final class AzureADAuthenticator {
httperror = 0;
ex = null;
try {
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod);
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod, isMsi);
} catch (HttpException e) {
httperror = e.httpErrorCode;
ex = e;
@ -288,8 +293,9 @@ public final class AzureADAuthenticator {
return token;
}
private static AzureADToken getTokenSingleCall(
String authEndpoint, String payload, Hashtable<String, String> headers, String httpMethod)
private static AzureADToken getTokenSingleCall(String authEndpoint,
String payload, Hashtable<String, String> headers, String httpMethod,
boolean isMsi)
throws IOException {
AzureADToken token = null;
@ -336,7 +342,7 @@ public final class AzureADAuthenticator {
if (httpResponseCode == HttpURLConnection.HTTP_OK
&& responseContentType.startsWith("application/json") && responseContentLength > 0) {
InputStream httpResponseStream = conn.getInputStream();
token = parseTokenFromStream(httpResponseStream);
token = parseTokenFromStream(httpResponseStream, isMsi);
} else {
InputStream stream = conn.getErrorStream();
if (stream == null) {
@ -390,10 +396,12 @@ public final class AzureADAuthenticator {
return token;
}
private static AzureADToken parseTokenFromStream(InputStream httpResponseStream) throws IOException {
private static AzureADToken parseTokenFromStream(
InputStream httpResponseStream, boolean isMsi) throws IOException {
AzureADToken token = new AzureADToken();
try {
int expiryPeriod = 0;
int expiryPeriodInSecs = 0;
long expiresOnInSecs = -1;
JsonFactory jf = new JsonFactory();
JsonParser jp = jf.createJsonParser(httpResponseStream);
@ -408,17 +416,38 @@ public final class AzureADAuthenticator {
if (fieldName.equals("access_token")) {
token.setAccessToken(fieldValue);
}
if (fieldName.equals("expires_in")) {
expiryPeriod = Integer.parseInt(fieldValue);
expiryPeriodInSecs = Integer.parseInt(fieldValue);
}
if (fieldName.equals("expires_on")) {
expiresOnInSecs = Long.parseLong(fieldValue);
}
}
jp.nextToken();
}
jp.close();
long expiry = System.currentTimeMillis();
expiry = expiry + expiryPeriod * 1000L; // convert expiryPeriod to milliseconds and add
token.setExpiry(new Date(expiry));
LOG.debug("AADToken: fetched token with expiry " + token.getExpiry().toString());
if (expiresOnInSecs > 0) {
LOG.debug("Expiry based on expires_on: {}", expiresOnInSecs);
token.setExpiry(new Date(expiresOnInSecs * 1000));
} else {
if (isMsi) {
// Currently there is a known issue that MSI does not update expires_in
// for refresh and will have the value from first AAD token fetch request.
// Due to this known limitation, expires_in is not supported for MSI token fetch flow.
throw new UnsupportedOperationException("MSI Responded with invalid expires_on");
}
LOG.debug("Expiry based on expires_in: {}", expiryPeriodInSecs);
long expiry = System.currentTimeMillis();
expiry = expiry + expiryPeriodInSecs * 1000L; // convert expiryPeriod to milliseconds and add
token.setExpiry(new Date(expiry));
}
LOG.debug("AADToken: fetched token with expiry {}, expiresOn passed: {}",
token.getExpiry().toString(), expiresOnInSecs);
} catch (Exception ex) {
LOG.debug("AADToken: got exception when parsing json token " + ex.toString());
throw ex;

View File

@ -36,6 +36,10 @@ public class MsiTokenProvider extends AccessTokenProvider {
private final String clientId;
private long tokenFetchTime = -1;
private static final long ONE_HOUR = 3600 * 1000;
private static final Logger LOG = LoggerFactory.getLogger(AccessTokenProvider.class);
public MsiTokenProvider(final String authEndpoint, final String tenantGuid,
@ -51,6 +55,36 @@ public class MsiTokenProvider extends AccessTokenProvider {
LOG.debug("AADToken: refreshing token from MSI");
AzureADToken token = AzureADAuthenticator
.getTokenFromMsi(authEndpoint, tenantGuid, clientId, authority, false);
tokenFetchTime = System.currentTimeMillis();
return token;
}
/**
* Checks if the token is about to expire as per base expiry logic.
* Otherwise try to expire every 1 hour
*
* @return true if the token is expiring in next 1 hour or if a token has
* never been fetched
*/
@Override
protected boolean isTokenAboutToExpire() {
if (tokenFetchTime == -1 || super.isTokenAboutToExpire()) {
return true;
}
boolean expiring = false;
long elapsedTimeSinceLastTokenRefreshInMillis =
System.currentTimeMillis() - tokenFetchTime;
expiring = elapsedTimeSinceLastTokenRefreshInMillis >= ONE_HOUR
|| elapsedTimeSinceLastTokenRefreshInMillis < 0;
// In case of, Token is not refreshed for 1 hr or any clock skew issues,
// refresh token.
if (expiring) {
LOG.debug("MSIToken: token renewing. Time elapsed since last token fetch:"
+ " {} milli seconds", elapsedTimeSinceLastTokenRefreshInMillis);
}
return expiring;
}
}

View File

@ -0,0 +1,93 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.fs.azurebfs;
import java.io.IOException;
import java.util.Date;
import org.junit.Test;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.fs.azurebfs.oauth2.AccessTokenProvider;
import org.apache.hadoop.fs.azurebfs.oauth2.AzureADToken;
import org.apache.hadoop.fs.azurebfs.oauth2.MsiTokenProvider;
import static org.junit.Assume.assumeThat;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.Matchers.isEmptyOrNullString;
import static org.hamcrest.Matchers.isEmptyString;
import static org.apache.hadoop.fs.azurebfs.constants.AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY;
import static org.apache.hadoop.fs.azurebfs.constants.AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT;
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID;
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY;
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT;
import static org.apache.hadoop.fs.azurebfs.constants.ConfigurationKeys.FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT;
/**
* Test MsiTokenProvider.
*/
public final class ITestAbfsMsiTokenProvider
extends AbstractAbfsIntegrationTest {
public ITestAbfsMsiTokenProvider() throws Exception {
super();
}
@Test
public void test() throws IOException {
AbfsConfiguration conf = getConfiguration();
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT),
not(isEmptyOrNullString()));
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT),
not(isEmptyOrNullString()));
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID),
not(isEmptyOrNullString()));
assumeThat(conf.get(FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY),
not(isEmptyOrNullString()));
String tenantGuid = conf
.getPasswordString(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT);
String clientId = conf.getPasswordString(FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID);
String authEndpoint = getTrimmedPasswordString(conf,
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT,
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT);
String authority = getTrimmedPasswordString(conf,
FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY,
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY);
AccessTokenProvider tokenProvider = new MsiTokenProvider(authEndpoint,
tenantGuid, clientId, authority);
AzureADToken token = null;
token = tokenProvider.getToken();
assertThat(token.getAccessToken(), not(isEmptyString()));
assertThat(token.getExpiry().after(new Date()), is(true));
}
private String getTrimmedPasswordString(AbfsConfiguration conf, String key,
String defaultValue) throws IOException {
String value = conf.getPasswordString(key);
if (StringUtils.isBlank(value)) {
value = defaultValue;
}
return value.trim();
}
}