NIFI-11386 Added Resource and Audience support to StandardOauth2AccessTokenProvider

- Also keeping previous Refresh Token if one is not provided during a refresh request

This closes #7164

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Tamas Palfy 2023-04-11 17:10:07 +02:00 committed by exceptionfactory
parent 1466b7d7d3
commit 88587f5c02
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
2 changed files with 127 additions and 43 deletions

View File

@ -163,6 +163,22 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.build();
public static final PropertyDescriptor RESOURCE = new PropertyDescriptor.Builder()
.name("resource")
.displayName("Resource")
.description("Resource URI for the access token request defined in RFC 8707 Section 2")
.required(false)
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.build();
public static final PropertyDescriptor AUDIENCE = new PropertyDescriptor.Builder()
.name("audience")
.displayName("Audience")
.description("Audience for the access token request defined in RFC 8693 Section 2.1")
.required(false)
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.build();
public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder()
.name("refresh-window")
.displayName("Refresh Window")
@ -199,6 +215,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
CLIENT_ID,
CLIENT_SECRET,
SCOPE,
RESOURCE,
AUDIENCE,
REFRESH_WINDOW,
SSL_CONTEXT,
HTTP_PROTOCOL_STRATEGY
@ -220,6 +238,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
private volatile String clientId;
private volatile String clientSecret;
private volatile String scope;
private volatile String resource;
private volatile String audience;
private volatile long refreshWindowSeconds;
private volatile AccessToken accessDetails;
@ -242,6 +262,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
scope = context.getProperty(SCOPE).getValue();
resource = context.getProperty(RESOURCE).getValue();
audience = context.getProperty(AUDIENCE).getValue();
if (context.getProperty(REFRESH_TOKEN).isSet()) {
String refreshToken = context.getProperty(REFRESH_TOKEN).evaluateAttributeExpressions().getValue();
@ -319,6 +341,14 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
return accessDetails;
}
private boolean isRefreshRequired() {
final Instant expirationRefreshTime = accessDetails.getFetchTime()
.plusSeconds(accessDetails.getExpiresIn())
.minusSeconds(refreshWindowSeconds);
return Instant.now().isAfter(expirationRefreshTime);
}
private void acquireAccessDetails() {
getLogger().debug("New Access Token request started [{}]", authorizationServerUrl);
@ -332,58 +362,59 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
acquireTokenBuilder.add("grant_type", "client_credentials");
}
if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) {
acquireTokenBuilder.add("client_id", clientId);
acquireTokenBuilder.add("client_secret", clientSecret);
}
addFormData(acquireTokenBuilder);
if (scope != null) {
acquireTokenBuilder.add("scope", scope);
}
RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
Request.Builder acquireTokenRequestBuilder = new Request.Builder()
.url(authorizationServerUrl)
.post(acquireTokenRequestBody);
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
acquireTokenRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
}
Request acquireTokenRequest = acquireTokenRequestBuilder.build();
this.accessDetails = getAccessDetails(acquireTokenRequest);
this.accessDetails = requestToken(acquireTokenBuilder);
}
private void refreshAccessDetails() {
getLogger().debug("Refresh Access Token request started [{}]", authorizationServerUrl);
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
.add("grant_type", "refresh_token")
.add("refresh_token", this.accessDetails.getRefreshToken());
.add("grant_type", "refresh_token")
.add("refresh_token", this.accessDetails.getRefreshToken());
if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) {
refreshTokenBuilder.add("client_id", clientId);
refreshTokenBuilder.add("client_secret", clientSecret);
addFormData(refreshTokenBuilder);
AccessToken newAccessDetails = requestToken(refreshTokenBuilder);
if (newAccessDetails.getRefreshToken() == null) {
newAccessDetails.setRefreshToken(this.accessDetails.getRefreshToken());
}
this.accessDetails = newAccessDetails;
}
private void addFormData(FormBody.Builder formBuilder) {
if (clientAuthenticationStrategy == ClientAuthenticationStrategy.REQUEST_BODY && clientId != null) {
formBuilder.add("client_id", clientId);
formBuilder.add("client_secret", clientSecret);
}
if (scope != null) {
refreshTokenBuilder.add("scope", scope);
formBuilder.add("scope", scope);
}
if (resource != null) {
formBuilder.add("resource", resource);
}
if (audience != null) {
formBuilder.add("audience", audience);
}
}
RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
private AccessToken requestToken(FormBody.Builder formBuilder) {
RequestBody requestBody = formBuilder.build();
Request.Builder refreshRequestBuilder = new Request.Builder()
.url(authorizationServerUrl)
.post(refreshTokenRequestBody);
Request.Builder requestBuilder = new Request.Builder()
.url(authorizationServerUrl)
.post(requestBody);
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
refreshRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
requestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
}
Request refreshRequest = refreshRequestBuilder.build();
Request request = requestBuilder.build();
this.accessDetails = getAccessDetails(refreshRequest);
return getAccessDetails(request);
}
private AccessToken getAccessDetails(final Request newRequest) {
@ -402,14 +433,6 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
}
}
private boolean isRefreshRequired() {
final Instant expirationRefreshTime = accessDetails.getFetchTime()
.plusSeconds(accessDetails.getExpiresIn())
.minusSeconds(refreshWindowSeconds);
return Instant.now().isAfter(expirationRefreshTime);
}
@Override
public List<ConfigVerificationResult> verify(ConfigurationContext context, ComponentLog verificationLogger, Map<String, String> variables) {
ConfigVerificationResult.Builder builder = new ConfigVerificationResult.Builder()

View File

@ -78,6 +78,9 @@ public class StandardOauth2AccessTokenProviderTest {
private static final String PASSWORD = "password";
private static final String CLIENT_ID = "clientId";
private static final String CLIENT_SECRET = "clientSecret";
private static final String SCOPE = "scope";
private static final String RESOURCE = "resource";
private static final String AUDIENCE = "audience";
private static final long FIVE_MINUTES = 300;
private static final int HTTP_OK = 200;
@ -120,6 +123,9 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.SCOPE).getValue()).thenReturn(SCOPE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.RESOURCE).getValue()).thenReturn(RESOURCE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUDIENCE).getValue()).thenReturn(AUDIENCE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
@ -361,6 +367,57 @@ public class StandardOauth2AccessTokenProviderTest {
assertEquals(expectedToken, actualToken);
}
@Test
public void testKeepPreviousRefreshTokenWhenNewOneIsNotProvided() throws Exception {
// GIVEN
String refreshTokenBeforeRefresh = "refresh_token";
Response response1 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test\", \"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
);
Response response2 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test_either\" }"
);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
// WHEN
testSubject.getAccessDetails();
String refreshTokenAfterRefresh = testSubject.getAccessDetails().getRefreshToken();
// THEN
assertEquals(refreshTokenBeforeRefresh, refreshTokenAfterRefresh);
}
@Test
public void testOverwritePreviousRefreshTokenWhenNewOneIsProvided() throws Exception {
// GIVEN
String refreshTokenBeforeRefresh = "refresh_token_before_refresh";
String expectedRefreshTokenAfterRefresh = "refresh_token_after_refresh";
Response response1 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test\", \"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
);
Response response2 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test_either\", \"refresh_token\":\"" + expectedRefreshTokenAfterRefresh + "\" }"
);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
// WHEN
testSubject.getAccessDetails();
String actualRefreshTokenAfterRefresh = testSubject.getAccessDetails().getRefreshToken();
// THEN
assertEquals(expectedRefreshTokenAfterRefresh, actualRefreshTokenAfterRefresh);
}
@Test
public void testBasicAuthentication() throws Exception {
// GIVEN
@ -377,7 +434,7 @@ public class StandardOauth2AccessTokenProviderTest {
}
@Test
public void testRequestBodyAuthentication() throws Exception {
public void testRequestBodyFormData() throws Exception {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
testSubject.onEnabled(mockContext);
@ -385,7 +442,11 @@ public class StandardOauth2AccessTokenProviderTest {
// GIVEN
Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}");
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID + "&client_secret=" + CLIENT_SECRET;
String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID
+ "&client_secret=" + CLIENT_SECRET
+ "&scope=" + SCOPE
+ "&resource=" + RESOURCE
+ "&audience=" + AUDIENCE;
// WHEN
testSubject.getAccessDetails();