NIFI-11386 Added Resource and Audience support to StandardOauth2AccessTokenProvider

- Also keeping previous Refresh Token if one is not provided during a refresh request

This closes #7164

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Tamas Palfy 2023-04-11 17:10:07 +02:00 committed by exceptionfactory
parent 1466b7d7d3
commit 88587f5c02
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
2 changed files with 127 additions and 43 deletions

View File

@ -163,6 +163,22 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR) .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.build(); .build();
public static final PropertyDescriptor RESOURCE = new PropertyDescriptor.Builder()
.name("resource")
.displayName("Resource")
.description("Resource URI for the access token request defined in RFC 8707 Section 2")
.required(false)
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.build();
public static final PropertyDescriptor AUDIENCE = new PropertyDescriptor.Builder()
.name("audience")
.displayName("Audience")
.description("Audience for the access token request defined in RFC 8693 Section 2.1")
.required(false)
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.build();
public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder() public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder()
.name("refresh-window") .name("refresh-window")
.displayName("Refresh Window") .displayName("Refresh Window")
@ -199,6 +215,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
CLIENT_ID, CLIENT_ID,
CLIENT_SECRET, CLIENT_SECRET,
SCOPE, SCOPE,
RESOURCE,
AUDIENCE,
REFRESH_WINDOW, REFRESH_WINDOW,
SSL_CONTEXT, SSL_CONTEXT,
HTTP_PROTOCOL_STRATEGY HTTP_PROTOCOL_STRATEGY
@ -220,6 +238,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
private volatile String clientId; private volatile String clientId;
private volatile String clientSecret; private volatile String clientSecret;
private volatile String scope; private volatile String scope;
private volatile String resource;
private volatile String audience;
private volatile long refreshWindowSeconds; private volatile long refreshWindowSeconds;
private volatile AccessToken accessDetails; private volatile AccessToken accessDetails;
@ -242,6 +262,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue(); clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
clientSecret = context.getProperty(CLIENT_SECRET).getValue(); clientSecret = context.getProperty(CLIENT_SECRET).getValue();
scope = context.getProperty(SCOPE).getValue(); scope = context.getProperty(SCOPE).getValue();
resource = context.getProperty(RESOURCE).getValue();
audience = context.getProperty(AUDIENCE).getValue();
if (context.getProperty(REFRESH_TOKEN).isSet()) { if (context.getProperty(REFRESH_TOKEN).isSet()) {
String refreshToken = context.getProperty(REFRESH_TOKEN).evaluateAttributeExpressions().getValue(); String refreshToken = context.getProperty(REFRESH_TOKEN).evaluateAttributeExpressions().getValue();
@ -319,6 +341,14 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
return accessDetails; return accessDetails;
} }
private boolean isRefreshRequired() {
final Instant expirationRefreshTime = accessDetails.getFetchTime()
.plusSeconds(accessDetails.getExpiresIn())
.minusSeconds(refreshWindowSeconds);
return Instant.now().isAfter(expirationRefreshTime);
}
private void acquireAccessDetails() { private void acquireAccessDetails() {
getLogger().debug("New Access Token request started [{}]", authorizationServerUrl); getLogger().debug("New Access Token request started [{}]", authorizationServerUrl);
@ -332,58 +362,59 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
acquireTokenBuilder.add("grant_type", "client_credentials"); acquireTokenBuilder.add("grant_type", "client_credentials");
} }
if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) { addFormData(acquireTokenBuilder);
acquireTokenBuilder.add("client_id", clientId);
acquireTokenBuilder.add("client_secret", clientSecret);
}
if (scope != null) { this.accessDetails = requestToken(acquireTokenBuilder);
acquireTokenBuilder.add("scope", scope);
}
RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
Request.Builder acquireTokenRequestBuilder = new Request.Builder()
.url(authorizationServerUrl)
.post(acquireTokenRequestBody);
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
acquireTokenRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
}
Request acquireTokenRequest = acquireTokenRequestBuilder.build();
this.accessDetails = getAccessDetails(acquireTokenRequest);
} }
private void refreshAccessDetails() { private void refreshAccessDetails() {
getLogger().debug("Refresh Access Token request started [{}]", authorizationServerUrl); getLogger().debug("Refresh Access Token request started [{}]", authorizationServerUrl);
FormBody.Builder refreshTokenBuilder = new FormBody.Builder() FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
.add("grant_type", "refresh_token") .add("grant_type", "refresh_token")
.add("refresh_token", this.accessDetails.getRefreshToken()); .add("refresh_token", this.accessDetails.getRefreshToken());
if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) { addFormData(refreshTokenBuilder);
refreshTokenBuilder.add("client_id", clientId);
refreshTokenBuilder.add("client_secret", clientSecret); AccessToken newAccessDetails = requestToken(refreshTokenBuilder);
if (newAccessDetails.getRefreshToken() == null) {
newAccessDetails.setRefreshToken(this.accessDetails.getRefreshToken());
} }
this.accessDetails = newAccessDetails;
}
private void addFormData(FormBody.Builder formBuilder) {
if (clientAuthenticationStrategy == ClientAuthenticationStrategy.REQUEST_BODY && clientId != null) {
formBuilder.add("client_id", clientId);
formBuilder.add("client_secret", clientSecret);
}
if (scope != null) { if (scope != null) {
refreshTokenBuilder.add("scope", scope); formBuilder.add("scope", scope);
} }
if (resource != null) {
formBuilder.add("resource", resource);
}
if (audience != null) {
formBuilder.add("audience", audience);
}
}
RequestBody refreshTokenRequestBody = refreshTokenBuilder.build(); private AccessToken requestToken(FormBody.Builder formBuilder) {
RequestBody requestBody = formBuilder.build();
Request.Builder refreshRequestBuilder = new Request.Builder() Request.Builder requestBuilder = new Request.Builder()
.url(authorizationServerUrl) .url(authorizationServerUrl)
.post(refreshTokenRequestBody); .post(requestBody);
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) { if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
refreshRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret)); requestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
} }
Request refreshRequest = refreshRequestBuilder.build(); Request request = requestBuilder.build();
this.accessDetails = getAccessDetails(refreshRequest); return getAccessDetails(request);
} }
private AccessToken getAccessDetails(final Request newRequest) { private AccessToken getAccessDetails(final Request newRequest) {
@ -402,14 +433,6 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
} }
} }
private boolean isRefreshRequired() {
final Instant expirationRefreshTime = accessDetails.getFetchTime()
.plusSeconds(accessDetails.getExpiresIn())
.minusSeconds(refreshWindowSeconds);
return Instant.now().isAfter(expirationRefreshTime);
}
@Override @Override
public List<ConfigVerificationResult> verify(ConfigurationContext context, ComponentLog verificationLogger, Map<String, String> variables) { public List<ConfigVerificationResult> verify(ConfigurationContext context, ComponentLog verificationLogger, Map<String, String> variables) {
ConfigVerificationResult.Builder builder = new ConfigVerificationResult.Builder() ConfigVerificationResult.Builder builder = new ConfigVerificationResult.Builder()

View File

@ -78,6 +78,9 @@ public class StandardOauth2AccessTokenProviderTest {
private static final String PASSWORD = "password"; private static final String PASSWORD = "password";
private static final String CLIENT_ID = "clientId"; private static final String CLIENT_ID = "clientId";
private static final String CLIENT_SECRET = "clientSecret"; private static final String CLIENT_SECRET = "clientSecret";
private static final String SCOPE = "scope";
private static final String RESOURCE = "resource";
private static final String AUDIENCE = "audience";
private static final long FIVE_MINUTES = 300; private static final long FIVE_MINUTES = 300;
private static final int HTTP_OK = 200; private static final int HTTP_OK = 200;
@ -120,6 +123,9 @@ public class StandardOauth2AccessTokenProviderTest {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.SCOPE).getValue()).thenReturn(SCOPE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.RESOURCE).getValue()).thenReturn(RESOURCE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUDIENCE).getValue()).thenReturn(AUDIENCE);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue()); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
@ -361,6 +367,57 @@ public class StandardOauth2AccessTokenProviderTest {
assertEquals(expectedToken, actualToken); assertEquals(expectedToken, actualToken);
} }
@Test
public void testKeepPreviousRefreshTokenWhenNewOneIsNotProvided() throws Exception {
// GIVEN
String refreshTokenBeforeRefresh = "refresh_token";
Response response1 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test\", \"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
);
Response response2 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test_either\" }"
);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
// WHEN
testSubject.getAccessDetails();
String refreshTokenAfterRefresh = testSubject.getAccessDetails().getRefreshToken();
// THEN
assertEquals(refreshTokenBeforeRefresh, refreshTokenAfterRefresh);
}
@Test
public void testOverwritePreviousRefreshTokenWhenNewOneIsProvided() throws Exception {
// GIVEN
String refreshTokenBeforeRefresh = "refresh_token_before_refresh";
String expectedRefreshTokenAfterRefresh = "refresh_token_after_refresh";
Response response1 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test\", \"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
);
Response response2 = buildResponse(
HTTP_OK,
"{ \"access_token\":\"not_checking_in_this_test_either\", \"refresh_token\":\"" + expectedRefreshTokenAfterRefresh + "\" }"
);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
// WHEN
testSubject.getAccessDetails();
String actualRefreshTokenAfterRefresh = testSubject.getAccessDetails().getRefreshToken();
// THEN
assertEquals(expectedRefreshTokenAfterRefresh, actualRefreshTokenAfterRefresh);
}
@Test @Test
public void testBasicAuthentication() throws Exception { public void testBasicAuthentication() throws Exception {
// GIVEN // GIVEN
@ -377,7 +434,7 @@ public class StandardOauth2AccessTokenProviderTest {
} }
@Test @Test
public void testRequestBodyAuthentication() throws Exception { public void testRequestBodyFormData() throws Exception {
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue()); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue()); when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
testSubject.onEnabled(mockContext); testSubject.onEnabled(mockContext);
@ -385,7 +442,11 @@ public class StandardOauth2AccessTokenProviderTest {
// GIVEN // GIVEN
Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}"); Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}");
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response); when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID + "&client_secret=" + CLIENT_SECRET; String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID
+ "&client_secret=" + CLIENT_SECRET
+ "&scope=" + SCOPE
+ "&resource=" + RESOURCE
+ "&audience=" + AUDIENCE;
// WHEN // WHEN
testSubject.getAccessDetails(); testSubject.getAccessDetails();