NIFI-12300 Add OAuth2 Support to RestLookupService (#8462)

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Greg Foreman 2024-04-26 00:50:42 -04:00 committed by GitHub
parent 0311bff213
commit 9098c013f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 143 additions and 9 deletions

View File

@ -148,6 +148,10 @@
<artifactId>nifi-schema-registry-service-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-oauth2-provider-api</artifactId>
</dependency>
</dependencies>
<build>
<plugins>

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.lookup;
import org.apache.nifi.components.DescribedValue;
public enum AuthenticationStrategy implements DescribedValue {
NONE("None","No Authentication"),
BASIC("Basic", "Basic Authentication"),
OAUTH2("OAuth2", "OAuth2 Authentication");
private final String displayName;
private final String description;
AuthenticationStrategy(final String displayName, final String description) {
this.displayName = displayName;
this.description = description;
}
@Override
public String getValue() {
return name();
}
@Override
public String getDisplayName() {
return displayName;
}
@Override
public String getDescription() {
return description;
}
}

View File

@ -41,6 +41,8 @@ import org.apache.nifi.components.Validator;
import org.apache.nifi.controller.AbstractControllerService;
import org.apache.nifi.controller.ConfigurationContext;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.proxy.ProxyConfiguration;
import org.apache.nifi.proxy.ProxyConfigurationService;
@ -125,11 +127,30 @@ public class RestLookupService extends AbstractControllerService implements Reco
.identifiesControllerService(SSLContextService.class)
.build();
public static final PropertyDescriptor AUTHENTICATION_STRATEGY = new PropertyDescriptor.Builder()
.name("rest-lookup-authentication-strategy")
.displayName("Authentication Strategy")
.description("Authentication strategy to use with REST service.")
.required(true)
.allowableValues(AuthenticationStrategy.class)
.defaultValue(AuthenticationStrategy.NONE)
.build();
public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder()
.name("rest-lookup-oauth2-access-token-provider")
.displayName("OAuth2 Access Token Provider")
.description("Enables managed retrieval of OAuth2 Bearer Token applied to HTTP requests using the Authorization Header.")
.identifiesControllerService(OAuth2AccessTokenProvider.class)
.required(true)
.dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.OAUTH2)
.build();
public static final PropertyDescriptor PROP_BASIC_AUTH_USERNAME = new PropertyDescriptor.Builder()
.name("rest-lookup-basic-auth-username")
.displayName("Basic Authentication Username")
.description("The username to be used by the client to authenticate against the Remote URL. Cannot include control characters (0-31), ':', or DEL (127).")
.required(false)
.dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC)
.expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
.addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x39\\x3b-\\x7e\\x80-\\xff]+$")))
.build();
@ -139,6 +160,7 @@ public class RestLookupService extends AbstractControllerService implements Reco
.displayName("Basic Authentication Password")
.description("The password to be used by the client to authenticate against the Remote URL.")
.required(false)
.dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC)
.sensitive(true)
.expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT)
.addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x7e\\x80-\\xff]+$")))
@ -150,6 +172,7 @@ public class RestLookupService extends AbstractControllerService implements Reco
.description("Whether to communicate with the website using Digest Authentication. 'Basic Authentication Username' and 'Basic Authentication Password' are used "
+ "for authentication.")
.required(false)
.dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC)
.defaultValue("false")
.allowableValues("true", "false")
.build();
@ -201,6 +224,8 @@ public class RestLookupService extends AbstractControllerService implements Reco
RECORD_PATH,
RESPONSE_HANDLING_STRATEGY,
SSL_CONTEXT_SERVICE,
AUTHENTICATION_STRATEGY,
OAUTH2_ACCESS_TOKEN_PROVIDER,
PROXY_CONFIGURATION_SERVICE,
PROP_BASIC_AUTH_USERNAME,
PROP_BASIC_AUTH_PASSWORD,
@ -225,6 +250,7 @@ public class RestLookupService extends AbstractControllerService implements Reco
private volatile String basicPass;
private volatile boolean isDigest;
private volatile ResponseHandlingStrategy responseHandlingStrategy;
private volatile Optional<OAuth2AccessTokenProvider> oauth2AccessTokenProviderOptional;
@OnEnabled
public void onEnabled(final ConfigurationContext context) {
@ -232,6 +258,14 @@ public class RestLookupService extends AbstractControllerService implements Reco
proxyConfigurationService = context.getProperty(PROXY_CONFIGURATION_SERVICE)
.asControllerService(ProxyConfigurationService.class);
if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) {
OAuth2AccessTokenProvider oauth2AccessTokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class);
oauth2AccessTokenProvider.getAccessDetails();
oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider);
} else {
oauth2AccessTokenProviderOptional = Optional.empty();
}
OkHttpClient.Builder builder = new OkHttpClient.Builder();
setAuthenticator(builder, context);
@ -363,6 +397,13 @@ public class RestLookupService extends AbstractControllerService implements Reco
}
}
@Override
public void migrateProperties(final PropertyConfiguration config) {
if (config.isPropertySet(PROP_BASIC_AUTH_USERNAME)) {
config.setProperty(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC.getValue());
}
}
protected void validateVerb(String method) throws LookupFailureException {
if (!VALID_VERBS.contains(method)) {
throw new LookupFailureException(String.format("%s is not a supported HTTP verb.", method));
@ -444,32 +485,38 @@ public class RestLookupService extends AbstractControllerService implements Reco
final MediaType mt = MediaType.parse(mimeType);
requestBody = RequestBody.create(body, mt);
}
Request.Builder request = new Request.Builder()
final Request.Builder request = new Request.Builder()
.url(endpoint);
switch (method) {
case "delete":
request = body != null ? request.delete(requestBody) : request.delete();
if (body != null) request.delete(requestBody); else request.delete();
break;
case "get":
request = request.get();
request.get();
break;
case "post":
request = request.post(requestBody);
request.post(requestBody);
break;
case "put":
request = request.put(requestBody);
request.put(requestBody);
break;
}
if (headers != null) {
for (Map.Entry<String, PropertyValue> header : headers.entrySet()) {
request = request.addHeader(header.getKey(), header.getValue().evaluateAttributeExpressions(context).getValue());
request.addHeader(header.getKey(), header.getValue().evaluateAttributeExpressions(context).getValue());
}
}
if (!basicUser.isEmpty() && !isDigest) {
String credential = Credentials.basic(basicUser, basicPass);
request = request.header("Authorization", credential);
if (!isDigest) {
if (!basicUser.isEmpty()) {
String credential = Credentials.basic(basicUser, basicPass);
request.header("Authorization", credential);
} else {
oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider ->
request.header("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken())
);
}
}
return request.build();

View File

@ -19,6 +19,7 @@ package org.apache.nifi.lookup;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.apache.nifi.oauth2.OAuth2AccessTokenProvider;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.serialization.RecordReader;
import org.apache.nifi.serialization.RecordReaderFactory;
@ -31,6 +32,7 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Answers;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
@ -51,6 +53,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@Timeout(10)
@ -197,6 +200,35 @@ class TestRestLookupService {
assertInstanceOf(IOException.class, exception.getCause());
}
@Test
void testOAuth2AuthorizationHeader() throws Exception {
String accessToken = "access_token";
String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId";
OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS);
when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId);
when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken);
runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider);
runner.enableControllerService(oauth2AccessTokenProvider);
runner.setProperty(RestLookupService.AUTHENTICATION_STRATEGY, AuthenticationStrategy.OAUTH2);
runner.setProperty(restLookupService, RestLookupService.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProvider.getIdentifier());
runner.enableControllerService(restLookupService);
when(recordReaderFactory.createRecordReader(any(), any(), anyLong(), any())).thenReturn(recordReader);
when(recordReader.nextRecord()).thenReturn(record);
mockWebServer.enqueue(new MockResponse());
final Optional<Record> recordFound = restLookupService.lookup(Collections.emptyMap());
assertTrue(recordFound.isPresent());
RecordedRequest recordedRequest = mockWebServer.takeRequest();
String actualAuthorizationHeader = recordedRequest.getHeader("Authorization");
assertEquals("Bearer " + accessToken, actualAuthorizationHeader);
}
private void assertRecordedRequestFound() throws InterruptedException {
final RecordedRequest request = mockWebServer.takeRequest();