mirror of https://github.com/apache/nifi.git
NIFI-11386 Added Resource and Audience support to StandardOauth2AccessTokenProvider
- Also keeping previous Refresh Token if one is not provided during a refresh request This closes #7164 Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
parent
1466b7d7d3
commit
88587f5c02
|
@ -163,6 +163,22 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||||
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
|
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor RESOURCE = new PropertyDescriptor.Builder()
|
||||||
|
.name("resource")
|
||||||
|
.displayName("Resource")
|
||||||
|
.description("Resource URI for the access token request defined in RFC 8707 Section 2")
|
||||||
|
.required(false)
|
||||||
|
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
public static final PropertyDescriptor AUDIENCE = new PropertyDescriptor.Builder()
|
||||||
|
.name("audience")
|
||||||
|
.displayName("Audience")
|
||||||
|
.description("Audience for the access token request defined in RFC 8693 Section 2.1")
|
||||||
|
.required(false)
|
||||||
|
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
|
||||||
|
.build();
|
||||||
|
|
||||||
public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder()
|
public static final PropertyDescriptor REFRESH_WINDOW = new PropertyDescriptor.Builder()
|
||||||
.name("refresh-window")
|
.name("refresh-window")
|
||||||
.displayName("Refresh Window")
|
.displayName("Refresh Window")
|
||||||
|
@ -199,6 +215,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||||
CLIENT_ID,
|
CLIENT_ID,
|
||||||
CLIENT_SECRET,
|
CLIENT_SECRET,
|
||||||
SCOPE,
|
SCOPE,
|
||||||
|
RESOURCE,
|
||||||
|
AUDIENCE,
|
||||||
REFRESH_WINDOW,
|
REFRESH_WINDOW,
|
||||||
SSL_CONTEXT,
|
SSL_CONTEXT,
|
||||||
HTTP_PROTOCOL_STRATEGY
|
HTTP_PROTOCOL_STRATEGY
|
||||||
|
@ -220,6 +238,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||||
private volatile String clientId;
|
private volatile String clientId;
|
||||||
private volatile String clientSecret;
|
private volatile String clientSecret;
|
||||||
private volatile String scope;
|
private volatile String scope;
|
||||||
|
private volatile String resource;
|
||||||
|
private volatile String audience;
|
||||||
private volatile long refreshWindowSeconds;
|
private volatile long refreshWindowSeconds;
|
||||||
|
|
||||||
private volatile AccessToken accessDetails;
|
private volatile AccessToken accessDetails;
|
||||||
|
@ -242,6 +262,8 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||||
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
|
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
|
||||||
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
|
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
|
||||||
scope = context.getProperty(SCOPE).getValue();
|
scope = context.getProperty(SCOPE).getValue();
|
||||||
|
resource = context.getProperty(RESOURCE).getValue();
|
||||||
|
audience = context.getProperty(AUDIENCE).getValue();
|
||||||
|
|
||||||
if (context.getProperty(REFRESH_TOKEN).isSet()) {
|
if (context.getProperty(REFRESH_TOKEN).isSet()) {
|
||||||
String refreshToken = context.getProperty(REFRESH_TOKEN).evaluateAttributeExpressions().getValue();
|
String refreshToken = context.getProperty(REFRESH_TOKEN).evaluateAttributeExpressions().getValue();
|
||||||
|
@ -319,6 +341,14 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||||
return accessDetails;
|
return accessDetails;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean isRefreshRequired() {
|
||||||
|
final Instant expirationRefreshTime = accessDetails.getFetchTime()
|
||||||
|
.plusSeconds(accessDetails.getExpiresIn())
|
||||||
|
.minusSeconds(refreshWindowSeconds);
|
||||||
|
|
||||||
|
return Instant.now().isAfter(expirationRefreshTime);
|
||||||
|
}
|
||||||
|
|
||||||
private void acquireAccessDetails() {
|
private void acquireAccessDetails() {
|
||||||
getLogger().debug("New Access Token request started [{}]", authorizationServerUrl);
|
getLogger().debug("New Access Token request started [{}]", authorizationServerUrl);
|
||||||
|
|
||||||
|
@ -332,58 +362,59 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||||
acquireTokenBuilder.add("grant_type", "client_credentials");
|
acquireTokenBuilder.add("grant_type", "client_credentials");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) {
|
addFormData(acquireTokenBuilder);
|
||||||
acquireTokenBuilder.add("client_id", clientId);
|
|
||||||
acquireTokenBuilder.add("client_secret", clientSecret);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (scope != null) {
|
this.accessDetails = requestToken(acquireTokenBuilder);
|
||||||
acquireTokenBuilder.add("scope", scope);
|
|
||||||
}
|
|
||||||
|
|
||||||
RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
|
|
||||||
Request.Builder acquireTokenRequestBuilder = new Request.Builder()
|
|
||||||
.url(authorizationServerUrl)
|
|
||||||
.post(acquireTokenRequestBody);
|
|
||||||
|
|
||||||
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
|
|
||||||
acquireTokenRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
|
|
||||||
}
|
|
||||||
|
|
||||||
Request acquireTokenRequest = acquireTokenRequestBuilder.build();
|
|
||||||
|
|
||||||
this.accessDetails = getAccessDetails(acquireTokenRequest);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void refreshAccessDetails() {
|
private void refreshAccessDetails() {
|
||||||
getLogger().debug("Refresh Access Token request started [{}]", authorizationServerUrl);
|
getLogger().debug("Refresh Access Token request started [{}]", authorizationServerUrl);
|
||||||
|
|
||||||
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
|
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
|
||||||
.add("grant_type", "refresh_token")
|
.add("grant_type", "refresh_token")
|
||||||
.add("refresh_token", this.accessDetails.getRefreshToken());
|
.add("refresh_token", this.accessDetails.getRefreshToken());
|
||||||
|
|
||||||
if (ClientAuthenticationStrategy.REQUEST_BODY == clientAuthenticationStrategy && clientId != null) {
|
addFormData(refreshTokenBuilder);
|
||||||
refreshTokenBuilder.add("client_id", clientId);
|
|
||||||
refreshTokenBuilder.add("client_secret", clientSecret);
|
AccessToken newAccessDetails = requestToken(refreshTokenBuilder);
|
||||||
|
|
||||||
|
if (newAccessDetails.getRefreshToken() == null) {
|
||||||
|
newAccessDetails.setRefreshToken(this.accessDetails.getRefreshToken());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.accessDetails = newAccessDetails;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addFormData(FormBody.Builder formBuilder) {
|
||||||
|
if (clientAuthenticationStrategy == ClientAuthenticationStrategy.REQUEST_BODY && clientId != null) {
|
||||||
|
formBuilder.add("client_id", clientId);
|
||||||
|
formBuilder.add("client_secret", clientSecret);
|
||||||
|
}
|
||||||
if (scope != null) {
|
if (scope != null) {
|
||||||
refreshTokenBuilder.add("scope", scope);
|
formBuilder.add("scope", scope);
|
||||||
}
|
}
|
||||||
|
if (resource != null) {
|
||||||
|
formBuilder.add("resource", resource);
|
||||||
|
}
|
||||||
|
if (audience != null) {
|
||||||
|
formBuilder.add("audience", audience);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
|
private AccessToken requestToken(FormBody.Builder formBuilder) {
|
||||||
|
RequestBody requestBody = formBuilder.build();
|
||||||
|
|
||||||
Request.Builder refreshRequestBuilder = new Request.Builder()
|
Request.Builder requestBuilder = new Request.Builder()
|
||||||
.url(authorizationServerUrl)
|
.url(authorizationServerUrl)
|
||||||
.post(refreshTokenRequestBody);
|
.post(requestBody);
|
||||||
|
|
||||||
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
|
if (ClientAuthenticationStrategy.BASIC_AUTHENTICATION == clientAuthenticationStrategy && clientId != null) {
|
||||||
refreshRequestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
|
requestBuilder.addHeader(AUTHORIZATION_HEADER, Credentials.basic(clientId, clientSecret));
|
||||||
}
|
}
|
||||||
|
|
||||||
Request refreshRequest = refreshRequestBuilder.build();
|
Request request = requestBuilder.build();
|
||||||
|
|
||||||
this.accessDetails = getAccessDetails(refreshRequest);
|
return getAccessDetails(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
private AccessToken getAccessDetails(final Request newRequest) {
|
private AccessToken getAccessDetails(final Request newRequest) {
|
||||||
|
@ -402,14 +433,6 @@ public class StandardOauth2AccessTokenProvider extends AbstractControllerService
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean isRefreshRequired() {
|
|
||||||
final Instant expirationRefreshTime = accessDetails.getFetchTime()
|
|
||||||
.plusSeconds(accessDetails.getExpiresIn())
|
|
||||||
.minusSeconds(refreshWindowSeconds);
|
|
||||||
|
|
||||||
return Instant.now().isAfter(expirationRefreshTime);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<ConfigVerificationResult> verify(ConfigurationContext context, ComponentLog verificationLogger, Map<String, String> variables) {
|
public List<ConfigVerificationResult> verify(ConfigurationContext context, ComponentLog verificationLogger, Map<String, String> variables) {
|
||||||
ConfigVerificationResult.Builder builder = new ConfigVerificationResult.Builder()
|
ConfigVerificationResult.Builder builder = new ConfigVerificationResult.Builder()
|
||||||
|
|
|
@ -78,6 +78,9 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||||
private static final String PASSWORD = "password";
|
private static final String PASSWORD = "password";
|
||||||
private static final String CLIENT_ID = "clientId";
|
private static final String CLIENT_ID = "clientId";
|
||||||
private static final String CLIENT_SECRET = "clientSecret";
|
private static final String CLIENT_SECRET = "clientSecret";
|
||||||
|
private static final String SCOPE = "scope";
|
||||||
|
private static final String RESOURCE = "resource";
|
||||||
|
private static final String AUDIENCE = "audience";
|
||||||
private static final long FIVE_MINUTES = 300;
|
private static final long FIVE_MINUTES = 300;
|
||||||
|
|
||||||
private static final int HTTP_OK = 200;
|
private static final int HTTP_OK = 200;
|
||||||
|
@ -120,6 +123,9 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
|
||||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
|
||||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.SCOPE).getValue()).thenReturn(SCOPE);
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.RESOURCE).getValue()).thenReturn(RESOURCE);
|
||||||
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUDIENCE).getValue()).thenReturn(AUDIENCE);
|
||||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.REFRESH_WINDOW).asTimePeriod(eq(TimeUnit.SECONDS))).thenReturn(FIVE_MINUTES);
|
||||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.BASIC_AUTHENTICATION.getValue());
|
||||||
|
|
||||||
|
@ -361,6 +367,57 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||||
assertEquals(expectedToken, actualToken);
|
assertEquals(expectedToken, actualToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testKeepPreviousRefreshTokenWhenNewOneIsNotProvided() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String refreshTokenBeforeRefresh = "refresh_token";
|
||||||
|
|
||||||
|
Response response1 = buildResponse(
|
||||||
|
HTTP_OK,
|
||||||
|
"{ \"access_token\":\"not_checking_in_this_test\", \"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
Response response2 = buildResponse(
|
||||||
|
HTTP_OK,
|
||||||
|
"{ \"access_token\":\"not_checking_in_this_test_either\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
testSubject.getAccessDetails();
|
||||||
|
String refreshTokenAfterRefresh = testSubject.getAccessDetails().getRefreshToken();
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
assertEquals(refreshTokenBeforeRefresh, refreshTokenAfterRefresh);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOverwritePreviousRefreshTokenWhenNewOneIsProvided() throws Exception {
|
||||||
|
// GIVEN
|
||||||
|
String refreshTokenBeforeRefresh = "refresh_token_before_refresh";
|
||||||
|
String expectedRefreshTokenAfterRefresh = "refresh_token_after_refresh";
|
||||||
|
|
||||||
|
Response response1 = buildResponse(
|
||||||
|
HTTP_OK,
|
||||||
|
"{ \"access_token\":\"not_checking_in_this_test\", \"expires_in\":\"0\", \"refresh_token\":\"" + refreshTokenBeforeRefresh + "\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
Response response2 = buildResponse(
|
||||||
|
HTTP_OK,
|
||||||
|
"{ \"access_token\":\"not_checking_in_this_test_either\", \"refresh_token\":\"" + expectedRefreshTokenAfterRefresh + "\" }"
|
||||||
|
);
|
||||||
|
|
||||||
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
|
||||||
|
|
||||||
|
// WHEN
|
||||||
|
testSubject.getAccessDetails();
|
||||||
|
String actualRefreshTokenAfterRefresh = testSubject.getAccessDetails().getRefreshToken();
|
||||||
|
|
||||||
|
// THEN
|
||||||
|
assertEquals(expectedRefreshTokenAfterRefresh, actualRefreshTokenAfterRefresh);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicAuthentication() throws Exception {
|
public void testBasicAuthentication() throws Exception {
|
||||||
// GIVEN
|
// GIVEN
|
||||||
|
@ -377,7 +434,7 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRequestBodyAuthentication() throws Exception {
|
public void testRequestBodyFormData() throws Exception {
|
||||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE.getValue());
|
||||||
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
|
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_AUTHENTICATION_STRATEGY).getValue()).thenReturn(ClientAuthenticationStrategy.REQUEST_BODY.getValue());
|
||||||
testSubject.onEnabled(mockContext);
|
testSubject.onEnabled(mockContext);
|
||||||
|
@ -385,7 +442,11 @@ public class StandardOauth2AccessTokenProviderTest {
|
||||||
// GIVEN
|
// GIVEN
|
||||||
Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}");
|
Response response = buildResponse(HTTP_OK, "{\"access_token\":\"foobar\"}");
|
||||||
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
|
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
|
||||||
String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID + "&client_secret=" + CLIENT_SECRET;
|
String expected = "grant_type=client_credentials&client_id=" + CLIENT_ID
|
||||||
|
+ "&client_secret=" + CLIENT_SECRET
|
||||||
|
+ "&scope=" + SCOPE
|
||||||
|
+ "&resource=" + RESOURCE
|
||||||
|
+ "&audience=" + AUDIENCE;
|
||||||
|
|
||||||
// WHEN
|
// WHEN
|
||||||
testSubject.getAccessDetails();
|
testSubject.getAccessDetails();
|
||||||
|
|
Loading…
Reference in New Issue