diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml index 3926ae32a6..f4d43f0621 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml +++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/pom.xml @@ -148,6 +148,10 @@ nifi-schema-registry-service-api test + + org.apache.nifi + nifi-oauth2-provider-api + diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/AuthenticationStrategy.java b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/AuthenticationStrategy.java new file mode 100644 index 0000000000..854c47dd2c --- /dev/null +++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/AuthenticationStrategy.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.lookup; + +import org.apache.nifi.components.DescribedValue; + +public enum AuthenticationStrategy implements DescribedValue { + + NONE("None","No Authentication"), + BASIC("Basic", "Basic Authentication"), + OAUTH2("OAuth2", "OAuth2 Authentication"); + + private final String displayName; + private final String description; + + AuthenticationStrategy(final String displayName, final String description) { + this.displayName = displayName; + this.description = description; + } + + @Override + public String getValue() { + return name(); + } + + @Override + public String getDisplayName() { + return displayName; + } + + @Override + public String getDescription() { + return description; + } + +} diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java index aa3d0864b2..a319b1c761 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/main/java/org/apache/nifi/lookup/RestLookupService.java @@ -41,6 +41,8 @@ import org.apache.nifi.components.Validator; import org.apache.nifi.controller.AbstractControllerService; import org.apache.nifi.controller.ConfigurationContext; import org.apache.nifi.expression.ExpressionLanguageScope; +import org.apache.nifi.migration.PropertyConfiguration; +import org.apache.nifi.oauth2.OAuth2AccessTokenProvider; import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.proxy.ProxyConfiguration; import org.apache.nifi.proxy.ProxyConfigurationService; @@ -125,11 +127,30 @@ public class RestLookupService extends AbstractControllerService implements Reco .identifiesControllerService(SSLContextService.class) .build(); + public static final PropertyDescriptor AUTHENTICATION_STRATEGY = new PropertyDescriptor.Builder() + .name("rest-lookup-authentication-strategy") + .displayName("Authentication Strategy") + .description("Authentication strategy to use with REST service.") + .required(true) + .allowableValues(AuthenticationStrategy.class) + .defaultValue(AuthenticationStrategy.NONE) + .build(); + + public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder() + .name("rest-lookup-oauth2-access-token-provider") + .displayName("OAuth2 Access Token Provider") + .description("Enables managed retrieval of OAuth2 Bearer Token applied to HTTP requests using the Authorization Header.") + .identifiesControllerService(OAuth2AccessTokenProvider.class) + .required(true) + .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.OAUTH2) + .build(); + public static final PropertyDescriptor PROP_BASIC_AUTH_USERNAME = new PropertyDescriptor.Builder() .name("rest-lookup-basic-auth-username") .displayName("Basic Authentication Username") .description("The username to be used by the client to authenticate against the Remote URL. Cannot include control characters (0-31), ':', or DEL (127).") .required(false) + .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC) .expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT) .addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x39\\x3b-\\x7e\\x80-\\xff]+$"))) .build(); @@ -139,6 +160,7 @@ public class RestLookupService extends AbstractControllerService implements Reco .displayName("Basic Authentication Password") .description("The password to be used by the client to authenticate against the Remote URL.") .required(false) + .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC) .sensitive(true) .expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT) .addValidator(StandardValidators.createRegexMatchingValidator(Pattern.compile("^[\\x20-\\x7e\\x80-\\xff]+$"))) @@ -150,6 +172,7 @@ public class RestLookupService extends AbstractControllerService implements Reco .description("Whether to communicate with the website using Digest Authentication. 'Basic Authentication Username' and 'Basic Authentication Password' are used " + "for authentication.") .required(false) + .dependsOn(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC) .defaultValue("false") .allowableValues("true", "false") .build(); @@ -201,6 +224,8 @@ public class RestLookupService extends AbstractControllerService implements Reco RECORD_PATH, RESPONSE_HANDLING_STRATEGY, SSL_CONTEXT_SERVICE, + AUTHENTICATION_STRATEGY, + OAUTH2_ACCESS_TOKEN_PROVIDER, PROXY_CONFIGURATION_SERVICE, PROP_BASIC_AUTH_USERNAME, PROP_BASIC_AUTH_PASSWORD, @@ -225,6 +250,7 @@ public class RestLookupService extends AbstractControllerService implements Reco private volatile String basicPass; private volatile boolean isDigest; private volatile ResponseHandlingStrategy responseHandlingStrategy; + private volatile Optional oauth2AccessTokenProviderOptional; @OnEnabled public void onEnabled(final ConfigurationContext context) { @@ -232,6 +258,14 @@ public class RestLookupService extends AbstractControllerService implements Reco proxyConfigurationService = context.getProperty(PROXY_CONFIGURATION_SERVICE) .asControllerService(ProxyConfigurationService.class); + if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) { + OAuth2AccessTokenProvider oauth2AccessTokenProvider = context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class); + oauth2AccessTokenProvider.getAccessDetails(); + oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider); + } else { + oauth2AccessTokenProviderOptional = Optional.empty(); + } + OkHttpClient.Builder builder = new OkHttpClient.Builder(); setAuthenticator(builder, context); @@ -363,6 +397,13 @@ public class RestLookupService extends AbstractControllerService implements Reco } } + @Override + public void migrateProperties(final PropertyConfiguration config) { + if (config.isPropertySet(PROP_BASIC_AUTH_USERNAME)) { + config.setProperty(AUTHENTICATION_STRATEGY, AuthenticationStrategy.BASIC.getValue()); + } + } + protected void validateVerb(String method) throws LookupFailureException { if (!VALID_VERBS.contains(method)) { throw new LookupFailureException(String.format("%s is not a supported HTTP verb.", method)); @@ -444,32 +485,38 @@ public class RestLookupService extends AbstractControllerService implements Reco final MediaType mt = MediaType.parse(mimeType); requestBody = RequestBody.create(body, mt); } - Request.Builder request = new Request.Builder() + final Request.Builder request = new Request.Builder() .url(endpoint); switch (method) { case "delete": - request = body != null ? request.delete(requestBody) : request.delete(); + if (body != null) request.delete(requestBody); else request.delete(); break; case "get": - request = request.get(); + request.get(); break; case "post": - request = request.post(requestBody); + request.post(requestBody); break; case "put": - request = request.put(requestBody); + request.put(requestBody); break; } if (headers != null) { for (Map.Entry header : headers.entrySet()) { - request = request.addHeader(header.getKey(), header.getValue().evaluateAttributeExpressions(context).getValue()); + request.addHeader(header.getKey(), header.getValue().evaluateAttributeExpressions(context).getValue()); } } - if (!basicUser.isEmpty() && !isDigest) { - String credential = Credentials.basic(basicUser, basicPass); - request = request.header("Authorization", credential); + if (!isDigest) { + if (!basicUser.isEmpty()) { + String credential = Credentials.basic(basicUser, basicPass); + request.header("Authorization", credential); + } else { + oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider -> + request.header("Authorization", "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken()) + ); + } } return request.build(); diff --git a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java index 724f5b8df0..04901cc123 100644 --- a/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java +++ b/nifi-nar-bundles/nifi-standard-services/nifi-lookup-services-bundle/nifi-lookup-services/src/test/java/org/apache/nifi/lookup/TestRestLookupService.java @@ -19,6 +19,7 @@ package org.apache.nifi.lookup; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; +import org.apache.nifi.oauth2.OAuth2AccessTokenProvider; import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.serialization.RecordReader; import org.apache.nifi.serialization.RecordReaderFactory; @@ -31,6 +32,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -51,6 +53,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @Timeout(10) @@ -197,6 +200,35 @@ class TestRestLookupService { assertInstanceOf(IOException.class, exception.getCause()); } + @Test + void testOAuth2AuthorizationHeader() throws Exception { + String accessToken = "access_token"; + String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId"; + + OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS); + when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId); + when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(accessToken); + runner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider); + runner.enableControllerService(oauth2AccessTokenProvider); + + runner.setProperty(RestLookupService.AUTHENTICATION_STRATEGY, AuthenticationStrategy.OAUTH2); + runner.setProperty(restLookupService, RestLookupService.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProvider.getIdentifier()); + runner.enableControllerService(restLookupService); + + when(recordReaderFactory.createRecordReader(any(), any(), anyLong(), any())).thenReturn(recordReader); + when(recordReader.nextRecord()).thenReturn(record); + mockWebServer.enqueue(new MockResponse()); + + final Optional recordFound = restLookupService.lookup(Collections.emptyMap()); + assertTrue(recordFound.isPresent()); + + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + + String actualAuthorizationHeader = recordedRequest.getHeader("Authorization"); + assertEquals("Bearer " + accessToken, actualAuthorizationHeader); + + } + private void assertRecordedRequestFound() throws InterruptedException { final RecordedRequest request = mockWebServer.takeRequest();