NIFI-9065 Add support for OAuth2AccessTokenProvider in InvokeHTTP

Signed-off-by: Pierre Villard <pierre.villard.fr@gmail.com>

This closes #5319.
This commit is contained in:
Tamas Palfy 2021-07-27 15:19:53 +02:00 committed by Pierre Villard
parent e603b0179b
commit aa61494fc3
No known key found for this signature in database
GPG Key ID: F92A93B30C07C6D5
11 changed files with 944 additions and 24 deletions

View File

@ -374,6 +374,11 @@
<artifactId>nifi-database-utils</artifactId>
<version>1.16.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-oauth2-provider-api</artifactId>
<version>1.16.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.sshd</groupId>
<artifactId>sshd-core</artifactId>

View File

@ -42,6 +42,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@ -90,6 +91,7 @@ import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.flowfile.attributes.CoreAttributes;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.processor.AbstractProcessor;
import org.apache.nifi.processor.DataUnit;
import org.apache.nifi.processor.ProcessContext;
@ -494,6 +496,13 @@ public class InvokeHTTP extends AbstractProcessor {
.allowableValues("True", "False")
.build();
public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder()
.name("oauth2-access-token-provider")
.displayName("OAuth2 Access Token provider")
.identifiesControllerService(OAuth2AccessTokenProvider.class)
.required(false)
.build();
public static final PropertyDescriptor FLOW_FILE_NAMING_STRATEGY = new PropertyDescriptor.Builder()
.name("flow-file-naming-strategy")
.description("Determines the strategy used for setting the filename attribute of the FlowFile.")
@ -527,6 +536,7 @@ public class InvokeHTTP extends AbstractProcessor {
PROP_USERAGENT,
PROP_BASIC_AUTH_USERNAME,
PROP_BASIC_AUTH_PASSWORD,
OAUTH2_ACCESS_TOKEN_PROVIDER,
PROXY_CONFIGURATION_SERVICE,
PROP_PROXY_HOST,
PROP_PROXY_PORT,
@ -595,6 +605,8 @@ public class InvokeHTTP extends AbstractProcessor {
private volatile boolean useChunked = false;
private volatile Optional<OAuth2AccessTokenProvider> oauth2AccessTokenProviderOptional;
private final AtomicReference<OkHttpClient> okHttpClientAtomicReference = new AtomicReference<>();
@Override
@ -728,6 +740,19 @@ public class InvokeHTTP extends AbstractProcessor {
.build());
}
boolean usingUserNamePasswordAuthorization = validationContext.getProperty(PROP_BASIC_AUTH_USERNAME).isSet()
|| validationContext.getProperty(PROP_BASIC_AUTH_PASSWORD).isSet();
boolean usingOAuth2Authorization = validationContext.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet();
if (usingUserNamePasswordAuthorization && usingOAuth2Authorization) {
results.add(new ValidationResult.Builder()
.subject("Authorization properties")
.valid(false)
.explanation("OAuth2 Authorization cannot be configured together with Username and Password properties")
.build());
}
return results;
}
@ -806,6 +831,19 @@ public class InvokeHTTP extends AbstractProcessor {
okHttpClientAtomicReference.set(okHttpClientBuilder.build());
}
@OnScheduled
public void initOauth2AccessTokenProvider(final ProcessContext context) {
if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) {
OAuth2AccessTokenProvider oauth2AccessTokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class);
oauth2AccessTokenProvider.getAccessDetails();
oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider);
} else {
oauth2AccessTokenProviderOptional = Optional.empty();
}
}
private void setAuthenticator(OkHttpClient.Builder okHttpClientBuilder, ProcessContext context) {
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
@ -1034,11 +1072,17 @@ public class InvokeHTTP extends AbstractProcessor {
final String authUser = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_USERNAME).getValue());
// If the username/password properties are set then check if digest auth is being used
if (!authUser.isEmpty() && "false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
if ("false".equalsIgnoreCase(context.getProperty(PROP_DIGEST_AUTH).getValue())) {
if (!authUser.isEmpty()) {
final String authPass = trimToEmpty(context.getProperty(PROP_BASIC_AUTH_PASSWORD).getValue());
String credential = Credentials.basic(authUser, authPass);
requestBuilder.header("Authorization", credential);
} else {
oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider ->
requestBuilder.addHeader("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken())
);
}
}
// set the request method

View File

@ -21,6 +21,7 @@ import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.flowfile.attributes.CoreAttributes;
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.standard.http.FlowFileNamingStrategy;
import org.apache.nifi.reporting.InitializationException;
@ -34,6 +35,7 @@ import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.apache.nifi.web.util.ssl.SslContextUtils;
import org.mockito.Answers;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
@ -774,6 +776,99 @@ public class InvokeHTTPTest {
);
}
@Test
public void testValidWhenOAuth2Set() throws Exception {
// GIVEN
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
runner.enableControllerService(oauth2AccessTokenProvider);
setUrlProperty();
// WHEN
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
// THEN
runner.assertValid();
}
@Test
public void testInvalidWhenOAuth2AndUserNameSet() throws Exception {
// GIVEN
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
runner.enableControllerService(oauth2AccessTokenProvider);
setUrlProperty();
// WHEN
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_USERNAME, "userName");
// THEN
runner.assertNotValid();
}
@Test
public void testInvalidWhenOAuth2AndPasswordSet() throws Exception {
// GIVEN
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
runner.enableControllerService(oauth2AccessTokenProvider);
setUrlProperty();
// WHEN
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
runner.setProperty(InvokeHTTP.PROP_BASIC_AUTH_PASSWORD, "password");
// THEN
runner.assertNotValid();
}
@Test
public void testOAuth2AuthorizationHeader() throws Exception {
// GIVEN
String accessToken = "access_token";
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken);
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
runner.enableControllerService(oauth2AccessTokenProvider);
setUrlProperty();
mockWebServer.enqueue(new MockResponse());
// WHEN
runner.setProperty(InvokeHTTP.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId);
runner.enqueue("unimportant");
runner.run();
// THEN
RecordedRequest recordedRequest = mockWebServer.takeRequest();
String actualAuthorizationHeader = recordedRequest.getHeader("Authorization");
assertEquals("Bearer " + accessToken, actualAuthorizationHeader);
}
private void setUrlProperty() {
runner.setProperty(InvokeHTTP.PROP_URL, getMockWebServerUrl());
}

View File

@ -17,53 +17,80 @@
package org.apache.nifi.oauth2;
import java.time.Duration;
import java.time.Instant;
public class AccessToken {
private String accessToken;
private String refreshToken;
private String tokenType;
private Integer expires;
private String scope;
private long expiresIn;
private String scopes;
private Long fetchTime;
private final Instant fetchTime;
public AccessToken(String accessToken,
String refreshToken,
String tokenType,
Integer expires,
String scope) {
public static final int EXPIRY_MARGIN = 5000;
public AccessToken() {
this.fetchTime = Instant.now();
}
public AccessToken(String accessToken, String refreshToken, String tokenType, long expiresIn, String scopes) {
this();
this.accessToken = accessToken;
this.tokenType = tokenType;
this.refreshToken = refreshToken;
this.expires = expires;
this.scope = scope;
this.fetchTime = System.currentTimeMillis();
this.tokenType = tokenType;
this.expiresIn = expiresIn;
this.scopes = scopes;
}
public String getAccessToken() {
return accessToken;
}
public void setAccessToken(String accessToken) {
this.accessToken = accessToken;
}
public String getRefreshToken() {
return refreshToken;
}
public void setRefreshToken(String refreshToken) {
this.refreshToken = refreshToken;
}
public String getTokenType() {
return tokenType;
}
public Integer getExpires() {
return expires;
public void setTokenType(String tokenType) {
this.tokenType = tokenType;
}
public String getScope() {
return scope;
public long getExpiresIn() {
return expiresIn;
}
public Long getFetchTime() {
public void setExpiresIn(long expiresIn) {
this.expiresIn = expiresIn;
}
public String getScopes() {
return scopes;
}
public void setScopes(String scopes) {
this.scopes = scopes;
}
public Instant getFetchTime() {
return fetchTime;
}
public boolean isExpired() {
return System.currentTimeMillis() >= ( fetchTime + (expires * 1000) );
boolean expired = Duration.between(Instant.now(), fetchTime.plusSeconds(expiresIn - EXPIRY_MARGIN)).isNegative();
return expired;
}
}

View File

@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.oauth2;
import org.apache.nifi.controller.ControllerService;
/**
* Controller service that provides OAuth2 access details
*/
public interface OAuth2AccessTokenProvider extends ControllerService {
/**
* @return A valid access token (refreshed automatically if needed) and additional metadata (provided by the OAuth2 access server)
*/
AccessToken getAccessDetails();
}

View File

@ -29,7 +29,10 @@ import java.util.List;
/**
* Interface for defining a credential-providing controller service for oauth2 processes.
*
* @deprecated use {@link OAuth2AccessTokenProvider} instead
*/
@Deprecated
public interface OAuth2TokenProvider extends ControllerService {
PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
.name("oauth2-ssl-context")

View File

@ -30,6 +30,7 @@ import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.DeprecationNotice;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.annotation.lifecycle.OnEnabled;
import org.apache.nifi.components.PropertyDescriptor;
@ -39,6 +40,8 @@ import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.ssl.SSLContextService;
import org.apache.nifi.util.StringUtils;
@Deprecated
@DeprecationNotice(alternatives = {StandardOauth2AccessTokenProvider.class})
@Tags({"oauth2", "provider", "authorization" })
@CapabilityDescription("This controller service provides a way of working with access and refresh tokens via the " +
"password and client_credential grant flows in the OAuth2 specification. It is meant to provide a way for components " +

View File

@ -0,0 +1,300 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.oauth2;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import okhttp3.FormBody;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.annotation.lifecycle.OnDisabled;
import org.apache.nifi.annotation.lifecycle.OnEnabled;
import org.apache.nifi.components.AllowableValue;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.controller.AbstractControllerService;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.ssl.SSLContextService;
import javax.net.ssl.SSLContext;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@Tags({"oauth2", "provider", "authorization", "access token", "http"})
@CapabilityDescription("Provides OAuth 2.0 access tokens that can be used as Bearer authorization header in HTTP requests." +
" Uses Resource Owner Password Credentials Grant.")
public class StandardOauth2AccessTokenProvider extends AbstractControllerService implements OAuth2AccessTokenProvider {
public static final PropertyDescriptor AUTHORIZATION_SERVER_URL = new PropertyDescriptor.Builder()
.name("authorization-server-url")
.displayName("Authorization Server URL")
.description("The URL of the authorization server that issues access tokens.")
.required(true)
.addValidator(StandardValidators.URL_VALIDATOR)
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
.build();
public static AllowableValue RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE = new AllowableValue(
"password",
"User Password",
"Resource Owner Password Credentials Grant. Used to access resources available to users. Requires username and password and usually Client ID and Client Secret"
);
public static AllowableValue CLIENT_CREDENTIALS_GRANT_TYPE = new AllowableValue(
"client_credentials",
"Client Credentials",
"Client Credentials Grant. Used to access resources available to clients. Requires Client ID and Client Secret"
);
public static final PropertyDescriptor GRANT_TYPE = new PropertyDescriptor.Builder()
.name("grant-type")
.displayName("Grant Type")
.description("The OAuth2 Grant Type to be used when acquiring an access token.")
.required(true)
.allowableValues(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE, CLIENT_CREDENTIALS_GRANT_TYPE)
.defaultValue(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())
.build();
public static final PropertyDescriptor USERNAME = new PropertyDescriptor.Builder()
.name("service-user-name")
.displayName("Username")
.description("Username on the service that is being accessed.")
.dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
.required(true)
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
.build();
public static final PropertyDescriptor PASSWORD = new PropertyDescriptor.Builder()
.name("service-password")
.displayName("Password")
.description("Password for the username on the service that is being accessed.")
.dependsOn(GRANT_TYPE, RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE)
.required(true)
.sensitive(true)
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
.build();
public static final PropertyDescriptor CLIENT_ID = new PropertyDescriptor.Builder()
.name("client-id")
.displayName("Client ID")
.required(false)
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
.expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY)
.build();
public static final PropertyDescriptor CLIENT_SECRET = new PropertyDescriptor.Builder()
.name("client-secret")
.displayName("Client secret")
.dependsOn(CLIENT_ID)
.required(true)
.sensitive(true)
.addValidator(StandardValidators.NON_BLANK_VALIDATOR)
.build();
public static final PropertyDescriptor SSL_CONTEXT = new PropertyDescriptor.Builder()
.name("ssl-context-service")
.displayName("SSL Context Servuce")
.addValidator(Validator.VALID)
.identifiesControllerService(SSLContextService.class)
.required(false)
.build();
private static final List<PropertyDescriptor> PROPERTIES = Collections.unmodifiableList(Arrays.asList(
AUTHORIZATION_SERVER_URL,
GRANT_TYPE,
USERNAME,
PASSWORD,
CLIENT_ID,
CLIENT_SECRET,
SSL_CONTEXT
));
public static final ObjectMapper ACCESS_DETAILS_MAPPER = new ObjectMapper()
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
private volatile String authorizationServerUrl;
private volatile OkHttpClient httpClient;
private volatile String grantType;
private volatile String username;
private volatile String password;
private volatile String clientId;
private volatile String clientSecret;
private volatile AccessToken accessDetails;
@Override
public List<PropertyDescriptor> getSupportedPropertyDescriptors() {
return PROPERTIES;
}
@OnEnabled
public void onEnabled(ConfigurationContext context) {
authorizationServerUrl = context.getProperty(AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue();
httpClient = createHttpClient(context);
grantType = context.getProperty(GRANT_TYPE).getValue();
username = context.getProperty(USERNAME).evaluateAttributeExpressions().getValue();
password = context.getProperty(PASSWORD).getValue();
clientId = context.getProperty(CLIENT_ID).evaluateAttributeExpressions().getValue();
clientSecret = context.getProperty(CLIENT_SECRET).getValue();
}
@OnDisabled
public void onDisabled() {
accessDetails = null;
}
@Override
protected Collection<ValidationResult> customValidate(ValidationContext validationContext) {
final List<ValidationResult> validationResults = new ArrayList<>(super.customValidate(validationContext));
if (
validationContext.getProperty(GRANT_TYPE).getValue().equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())
&& !validationContext.getProperty(CLIENT_ID).isSet()
) {
validationResults.add(new ValidationResult.Builder().subject(CLIENT_ID.getDisplayName())
.valid(false)
.explanation(String.format(
"When '%s' is set to '%s', '%s' is required",
GRANT_TYPE.getDisplayName(),
CLIENT_CREDENTIALS_GRANT_TYPE.getDisplayName(),
CLIENT_ID.getDisplayName())
)
.build());
}
return validationResults;
}
protected OkHttpClient createHttpClient(ConfigurationContext context) {
OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder();
SSLContextService sslService = context.getProperty(SSL_CONTEXT).asControllerService(SSLContextService.class);
if (sslService != null) {
final X509TrustManager trustManager = sslService.createTrustManager();
SSLContext sslContext = sslService.createContext();
clientBuilder.sslSocketFactory(sslContext.getSocketFactory(), trustManager);
}
return clientBuilder.build();
}
@Override
public AccessToken getAccessDetails() {
if (this.accessDetails == null) {
acquireAccessDetails();
} else if (this.accessDetails.isExpired()) {
if (this.accessDetails.getRefreshToken() == null) {
acquireAccessDetails();
} else {
try {
refreshAccessDetails();
} catch (Exception e) {
getLogger().info("Couldn't refresh access token", e);
acquireAccessDetails();
}
}
}
return accessDetails;
}
private void acquireAccessDetails() {
getLogger().debug("Getting a new access token");
FormBody.Builder acquireTokenBuilder = new FormBody.Builder();
if (grantType.equals(RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue())) {
acquireTokenBuilder.add("grant_type", "password")
.add("username", username)
.add("password", password);
} else if (grantType.equals(CLIENT_CREDENTIALS_GRANT_TYPE.getValue())) {
acquireTokenBuilder.add("grant_type", "client_credentials");
}
if (clientId != null) {
acquireTokenBuilder.add("client_id", clientId);
acquireTokenBuilder.add("client_secret", clientSecret);
}
RequestBody acquireTokenRequestBody = acquireTokenBuilder.build();
Request acquireTokenRequest = new Request.Builder()
.url(authorizationServerUrl)
.post(acquireTokenRequestBody)
.build();
this.accessDetails = getAccessDetails(acquireTokenRequest);
}
private void refreshAccessDetails() {
getLogger().debug("Refreshing access token");
FormBody.Builder refreshTokenBuilder = new FormBody.Builder()
.add("grant_type", "refresh_token")
.add("refresh_token", this.accessDetails.getRefreshToken());
if (clientId != null) {
refreshTokenBuilder.add("client_id", clientId);
refreshTokenBuilder.add("client_secret", clientSecret);
}
RequestBody refreshTokenRequestBody = refreshTokenBuilder.build();
Request refreshRequest = new Request.Builder()
.url(authorizationServerUrl)
.post(refreshTokenRequestBody)
.build();
this.accessDetails = getAccessDetails(refreshRequest);
}
private AccessToken getAccessDetails(Request newRequest) {
try {
Response response = httpClient.newCall(newRequest).execute();
String responseBody = response.body().string();
if (response.code() != 200) {
getLogger().error(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", response.code(), responseBody));
throw new ProcessException(String.format("OAuth2 access token request failed [HTTP %d]", response.code()));
}
AccessToken accessDetails = ACCESS_DETAILS_MAPPER.readValue(responseBody, AccessToken.class);
return accessDetails;
} catch (IOException e) {
throw new UncheckedIOException("OAuth2 access token request failed", e);
}
}
}

View File

@ -13,3 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
org.apache.nifi.oauth2.OAuth2TokenProviderImpl
org.apache.nifi.oauth2.StandardOauth2AccessTokenProvider

View File

@ -107,7 +107,7 @@ public class OAuth2TokenProviderImplTest {
private void assertAccessTokenFound(final AccessToken accessToken) {
assertNotNull(accessToken);
assertEquals("access token", accessToken.getAccessToken());
assertEquals(300, accessToken.getExpires().intValue());
assertEquals(5300, accessToken.getExpiresIn());
assertEquals("BEARER", accessToken.getTokenType());
assertFalse(accessToken.isExpired());
}
@ -117,7 +117,7 @@ public class OAuth2TokenProviderImplTest {
token.put("access_token", "access token");
token.put("refresh_token", "refresh token");
token.put("token_type", "BEARER");
token.put("expires_in", 300);
token.put("expires_in", 5300);
token.put("scope", "test scope");
final String accessToken = new ObjectMapper().writeValueAsString(token);
mockWebServer.enqueue(new MockResponse().setResponseCode(200).addHeader("Content-Type", "application/json").setBody(accessToken));

View File

@ -0,0 +1,411 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.oauth2;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processor.Processor;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.util.NoOpProcessor;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class StandardOauth2AccessTokenProviderTest {
private static final String AUTHORIZATION_SERVER_URL = "http://authorizationServerUrl";
private static final String USERNAME = "username";
private static final String PASSWORD = "password";
private static final String CLIENT_ID = "clientId";
private static final String CLIENT_SECRET = "clientSecret";
private StandardOauth2AccessTokenProvider testSubject;
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private OkHttpClient mockHttpClient;
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private ConfigurationContext mockContext;
@Mock
private ComponentLog mockLogger;
@Captor
private ArgumentCaptor<String> debugCaptor;
@Captor
private ArgumentCaptor<String> errorCaptor;
@Captor
private ArgumentCaptor<Throwable> throwableCaptor;
@Before
public void setUp() throws Exception {
MockitoAnnotations.initMocks(this);
testSubject = new StandardOauth2AccessTokenProvider() {
@Override
protected OkHttpClient createHttpClient(ConfigurationContext context) {
return mockHttpClient;
}
@Override
protected ComponentLog getLogger() {
return mockLogger;
}
};
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.GRANT_TYPE).getValue()).thenReturn(StandardOauth2AccessTokenProvider.RESOURCE_OWNER_PASSWORD_CREDENTIALS_GRANT_TYPE.getValue());
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL).evaluateAttributeExpressions().getValue()).thenReturn(AUTHORIZATION_SERVER_URL);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.USERNAME).evaluateAttributeExpressions().getValue()).thenReturn(USERNAME);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.PASSWORD).getValue()).thenReturn(PASSWORD);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_ID).evaluateAttributeExpressions().getValue()).thenReturn(CLIENT_ID);
when(mockContext.getProperty(StandardOauth2AccessTokenProvider.CLIENT_SECRET).getValue()).thenReturn(CLIENT_SECRET);
testSubject.onEnabled(mockContext);
}
@Test
public void testInvalidWhenClientCredentialsGrantTypeSetWithoutClientId() throws Exception {
// GIVEN
Processor processor = new NoOpProcessor();
TestRunner runner = TestRunners.newTestRunner(processor);
runner.addControllerService("testSubject", testSubject);
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
// WHEN
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
// THEN
runner.assertNotValid(testSubject);
}
@Test
public void testValidWhenClientCredentialsGrantTypeSetWithClientId() throws Exception {
// GIVEN
Processor processor = new NoOpProcessor();
TestRunner runner = TestRunners.newTestRunner(processor);
runner.addControllerService("testSubject", testSubject);
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.AUTHORIZATION_SERVER_URL, "http://unimportant");
// WHEN
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.GRANT_TYPE, StandardOauth2AccessTokenProvider.CLIENT_CREDENTIALS_GRANT_TYPE);
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_ID, "clientId");
runner.setProperty(testSubject, StandardOauth2AccessTokenProvider.CLIENT_SECRET, "clientSecret");
// THEN
runner.assertValid(testSubject);
}
@Test
public void testAcquireNewToken() throws Exception {
String accessTokenValue = "access_token_value";
// GIVEN
Response response = buildResponse(
200,
"{ \"access_token\":\"" + accessTokenValue + "\" }"
);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response);
// WHEN
String actual = testSubject.getAccessDetails().getAccessToken();
// THEN
assertEquals(accessTokenValue, actual);
}
@Test
public void testRefreshToken() throws Exception {
// GIVEN
String firstToken = "first_token";
String expectedToken = "second_token";
Response response1 = buildResponse(
200,
"{ \"access_token\":\"" + firstToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
);
Response response2 = buildResponse(
200,
"{ \"access_token\":\"" + expectedToken + "\" }"
);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenReturn(response1, response2);
// WHEN
testSubject.getAccessDetails();
String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
assertEquals(expectedToken, actualToken);
}
@Test
public void testIOExceptionDuringRefreshAndSubsequentAcquire() throws Exception {
// GIVEN
String refreshErrorMessage = "refresh_error";
String acquireErrorMessage = "acquire_error";
AtomicInteger callCounter = new AtomicInteger(0);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
callCounter.incrementAndGet();
if (callCounter.get() == 1) {
return buildSuccessfulInitResponse();
} else if (callCounter.get() == 2) {
throw new IOException(refreshErrorMessage);
} else if (callCounter.get() == 3) {
throw new IOException(acquireErrorMessage);
}
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
});
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
UncheckedIOException actualException = assertThrows(
UncheckedIOException.class,
() -> testSubject.getAccessDetails()
);
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
checkError(new UncheckedIOException("OAuth2 access token request failed", new IOException(acquireErrorMessage)), actualException);
}
@Test
public void testIOExceptionDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
// GIVEN
String refreshErrorMessage = "refresh_error";
String expectedToken = "expected_token";
Response successfulAcquireResponse = buildResponse(
200,
"{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
);
AtomicInteger callCounter = new AtomicInteger(0);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
callCounter.incrementAndGet();
if (callCounter.get() == 1) {
return buildSuccessfulInitResponse();
} else if (callCounter.get() == 2) {
throw new IOException(refreshErrorMessage);
} else if (callCounter.get() == 3) {
return successfulAcquireResponse;
}
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
});
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new UncheckedIOException("OAuth2 access token request failed", new IOException(refreshErrorMessage)));
assertEquals(expectedToken, actualToken);
}
@Test
public void testHTTPErrorDuringRefreshAndSubsequentAcquire() throws Exception {
// GIVEN
String errorRefreshResponseBody = "{ \"error_response\":\"refresh_error\" }";
String errorAcquireResponseBody = "{ \"error_response\":\"acquire_error\" }";
Response errorRefreshResponse = buildResponse(500, errorRefreshResponseBody);
Response errorAcquireResponse = buildResponse(503, errorAcquireResponseBody);
AtomicInteger callCounter = new AtomicInteger(0);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
callCounter.incrementAndGet();
if (callCounter.get() == 1) {
return buildSuccessfulInitResponse();
} else if (callCounter.get() == 2) {
return errorRefreshResponse;
} else if (callCounter.get() == 3) {
return errorAcquireResponse;
}
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
});
List<String> expectedLoggedError = Arrays.asList(
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, errorRefreshResponseBody),
String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 503, errorAcquireResponseBody)
);
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
ProcessException actualException = assertThrows(
ProcessException.class,
() -> testSubject.getAccessDetails()
);
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
checkError(new ProcessException("OAuth2 access token request failed [HTTP 503]"), actualException);
}
@Test
public void testHTTPErrorDuringRefreshSuccessfulSubsequentAcquire() throws Exception {
// GIVEN
String expectedRefreshErrorResponse = "{ \"error_response\":\"refresh_error\" }";
String expectedToken = "expected_token";
Response errorRefreshResponse = buildResponse(500, expectedRefreshErrorResponse);
Response successfulAcquireResponse = buildResponse(
200,
"{ \"access_token\":\"" + expectedToken + "\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
);
AtomicInteger callCounter = new AtomicInteger(0);
when(mockHttpClient.newCall(any(Request.class)).execute()).thenAnswer(invocation -> {
callCounter.incrementAndGet();
if (callCounter.get() == 1) {
return buildSuccessfulInitResponse();
} else if (callCounter.get() == 2) {
return errorRefreshResponse;
} else if (callCounter.get() == 3) {
return successfulAcquireResponse;
}
throw new IllegalStateException("Test improperly defined mock HTTP responses.");
});
List<String> expectedLoggedError = Arrays.asList(String.format("OAuth2 access token request failed [HTTP %d], response:%n%s", 500, expectedRefreshErrorResponse));
// Get a good accessDetails so we can have a refresh a second time
testSubject.getAccessDetails();
// WHEN
String actualToken = testSubject.getAccessDetails().getAccessToken();
// THEN
checkLoggedDebugWhenRefreshFails();
checkLoggedRefreshError(new ProcessException("OAuth2 access token request failed [HTTP 500]"));
checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(expectedLoggedError);
assertEquals(expectedToken, actualToken);
}
private Response buildSuccessfulInitResponse() {
return buildResponse(
200,
"{ \"access_token\":\"exists_but_value_irrelevant\", \"expires_in\":\"0\", \"refresh_token\":\"not_checking_in_this_test\" }"
);
}
private Response buildResponse(int code, String body) {
return new Response.Builder()
.request(new Request.Builder()
.url("http://unimportant_but_required")
.build()
)
.protocol(Protocol.HTTP_2)
.message("unimportant_but_required")
.code(code)
.body(ResponseBody.create(
body.getBytes(),
MediaType.parse("application/json"))
)
.build();
}
private void checkLoggedDebugWhenRefreshFails() {
verify(mockLogger, times(3)).debug(debugCaptor.capture());
List<String> actualDebugMessages = debugCaptor.getAllValues();
assertEquals(
Arrays.asList("Getting a new access token", "Refreshing access token", "Getting a new access token"),
actualDebugMessages
);
}
private void checkedLoggedErrorWhenRefreshReturnsBadHTTPResponse(List<String> expectedLoggedError) {
verify(mockLogger, times(expectedLoggedError.size())).error(errorCaptor.capture());
List<String> actualLoggedError = errorCaptor.getAllValues();
assertEquals(expectedLoggedError, actualLoggedError);
}
private void checkLoggedRefreshError(Throwable expectedRefreshError) {
verify(mockLogger).info(eq("Couldn't refresh access token"), throwableCaptor.capture());
Throwable actualRefreshError = throwableCaptor.getValue();
checkError(expectedRefreshError, actualRefreshError);
}
private void checkError(Throwable expectedError, Throwable actualError) {
assertEquals(expectedError.getClass(), actualError.getClass());
assertEquals(expectedError.getMessage(), actualError.getMessage());
if (expectedError.getCause() != null || actualError.getCause() != null) {
assertEquals(expectedError.getCause().getClass(), actualError.getCause().getClass());
assertEquals(expectedError.getCause().getMessage(), actualError.getCause().getMessage());
}
}
}