NIFI-12240 Added Python Processors for Docs, ChatGPT, Chroma, and Pinecone

Created new Python Processors for text embeddings, inserting into and querying Chroma, prompting ChatGPT, and inserting into and querying Pinecone. Fixed some bugs in the Python framework and added the Python extensions to the assembly. Also added the ability to load dependencies from a requirements.txt file, which was important for making the different vector store implementations play more nicely together.

Excluded nifi-python-extensions-bundle from the GitHub build because it relies on Maven's unpack-dependencies goal, which will not work in the GitHub workflow because it runs mvn compile instead of mvn install.

- ParseDocument
- ChunkDocument
- PromptChatGPT
- PutChroma
- PutPinecone
- QueryChroma
- QueryPinecone

NIFI-12195 Added support for requirements.txt to define Python dependencies

This closes #7894

Signed-off-by: David Handermann <exceptionfactory@apache.org>
Mark Payne 2023-09-20 17:39:10 -04:00 committed by exceptionfactory
parent 945d8b54bc
commit 5bcad9eef3
GPG Key ID: 29B6A52D2AAE8DBA
37 changed files with 2366 additions and 110 deletions

View File

@ -56,6 +56,7 @@ env:
-pl -:nifi-system-test-suite
-pl -:nifi-nar-provider-assembly
-pl -:nifi-py4j-integration-tests
-pl -:nifi-python-extensions-bundle
MAVEN_VERIFY_COMMAND: >-
verify
--show-version

View File

@ -894,6 +894,13 @@ language governing permissions and limitations under the License. -->
<version>2.0.0-SNAPSHOT</version>
<type>nar</type>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-python-extensions-bundle</artifactId>
<version>2.0.0-SNAPSHOT</version>
<type>zip</type>
</dependency>
<!-- AspectJ library needed by the Java Agent used for native library loading (see bootstrap.conf) -->
<dependency>
<groupId>org.aspectj</groupId>

View File

@ -41,6 +41,27 @@
<exclude>org.aspectj:aspectjweaver</exclude>
</excludes>
</dependencySet>
<!-- Unpack Python extensions -->
<dependencySet>
<scope>runtime</scope>
<useProjectArtifact>false</useProjectArtifact>
<outputDirectory>./python/extensions</outputDirectory>
<directoryMode>0770</directoryMode>
<fileMode>0664</fileMode>
<useTransitiveFiltering>true</useTransitiveFiltering>
<includes>
<include>*:nifi-python-extensions-bundle</include>
</includes>
<unpack>true</unpack>
<unpackOptions>
<excludes>
<exclude>META-INF/</exclude>
<exclude>META-INF/**</exclude>
</excludes>
</unpackOptions>
</dependencySet>
</dependencySets>
</assembly>

View File

@ -779,6 +779,34 @@ public class StandardValidators {
};
}
public static Validator createNonNegativeFloatingPointValidator(final double maximum) {
return new Validator() {
@Override
public ValidationResult validate(final String subject, final String input, final ValidationContext context) {
if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(input)) {
return new ValidationResult.Builder().subject(subject).input(input).explanation("Expression Language Present").valid(true).build();
}
String reason = null;
try {
final double doubleValue = Double.parseDouble(input);
if (doubleValue < 0) {
reason = "Value must be non-negative but was " + doubleValue;
}
final double maxPlusDelta = maximum + 0.00001D;
if (doubleValue < 0 || doubleValue > maxPlusDelta) {
reason = "Value must be between 0 and " + maximum + " but was " + doubleValue;
}
} catch (final NumberFormatException e) {
reason = "not a valid integer";
}
return new ValidationResult.Builder().subject(subject).input(input).explanation(reason).valid(reason == null).build();
}
};
}
//
//
// SPECIFIC VALIDATOR IMPLEMENTATIONS THAT CANNOT BE ANONYMOUS CLASSES

View File

@ -485,10 +485,33 @@ to pickup any changes seamlessly as soon as the Processor is started.
[[dependencies]]
== Adding Third-Party Dependencies
Python based Processors can be a single module, or they can be bundled together as a Python package. How you specify third-party dependencies depends on how
the Processor is packaged.
Third-party dependencies are defined for a Processor using the `dependencies` member of the `ProcessorDetails` inner class.
This is a list of Strings that indicate the PyPI modules that the Processor depends on. The format is the same format expected
by PyPI.
=== Package-level Dependencies
If one or more Processors are defined within a Python package, the package should define a `requirements.txt` file that declares all third-party dependencies
that are necessary for any Processor in the package. The file structure will then typically look like this:
----
my-python-package/
├── __init__.py
├── ProcessorA.py
├── ProcessorB.py
└── requirements.txt
----
In this way, all of the requirements are loaded from the `requirements.txt` file once for the entire package, rather than once for ProcessorA and again for ProcessorB.
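For illustration, such a `requirements.txt` simply lists the PyPI packages, with optional version constraints, needed by any Processor in the package (the packages named here are only examples):
----
# my-python-package/requirements.txt
pandas==2.0.3
openpyxl>=3.1
----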
=== Processor-Level Dependencies
If your Processor is not a part of a Python package, its dependencies can be declared using the `dependencies` member of the `ProcessorDetails` inner class.
This is a list of Strings that indicate the PyPI modules that the Processor depends on. The format is the same format expected by PyPI.
This provides a convenience for declaring third-party dependencies without requiring that Processors be bundled into a package.
For example, to indicate that a Processor needs `pandas` installed, the implementation might
look like this:
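A minimal sketch of such a single-module Processor follows; the class name and the CSV-summarizing logic are only illustrative, but the `dependencies` declaration is the part being described here:
----
import io

from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult


class SummarizeCsv(FlowFileTransform):
    class Java:
        implements = ['org.apache.nifi.python.processor.FlowFileTransform']

    class ProcessorDetails:
        version = '0.0.1-SNAPSHOT'
        description = 'Summarizes the columns of an incoming CSV FlowFile using pandas'
        # Third-party PyPI dependencies for this single-module Processor
        dependencies = ['pandas']

    def transform(self, context, flowfile):
        # pandas is importable once the framework has installed the declared dependencies
        import pandas as pd
        df = pd.read_csv(io.BytesIO(flowfile.getContentsAsBytes()))
        return FlowFileTransformResult("success", contents=df.describe().to_csv())
----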

View File

@ -1187,7 +1187,7 @@ public abstract class AbstractComponentNode implements ComponentNode {
public PropertyDescriptor getPropertyDescriptor(final String name) {
try (final NarCloseable narCloseable = NarCloseable.withComponentNarLoader(extensionManager, getComponent().getClass(), getComponent().getIdentifier())) {
final PropertyDescriptor propertyDescriptor = getComponent().getPropertyDescriptor(name);
if (propertyDescriptor.isDynamic() && sensitiveDynamicPropertyNames.get().contains(name)) {
if (propertyDescriptor.isDynamic() && isSensitiveDynamicProperty(name)) {
return new PropertyDescriptor.Builder().fromPropertyDescriptor(propertyDescriptor).sensitive(true).build();
} else {
return propertyDescriptor;

View File

@ -542,7 +542,7 @@ public final class StandardProcessScheduler implements ProcessScheduler {
try {
final Set<URL> additionalUrls = procNode.getAdditionalClasspathResources(procNode.getPropertyDescriptors());
flowController.getReloadComponent().reload(procNode, procNode.getProcessor().getClass().getName(), procNode.getBundleCoordinate(), additionalUrls);
flowController.getReloadComponent().reload(procNode, procNode.getCanonicalClassName(), procNode.getBundleCoordinate(), additionalUrls);
} catch (final ProcessorInstantiationException e) {
// This shouldn't happen because we already have been able to instantiate the processor before
LOG.error("Failed to replace instance of Processor for {} when terminating Processor", procNode);

View File

@ -160,6 +160,9 @@ public class PythonProcess {
final List<String> commands = new ArrayList<>();
commands.add(pythonCommand);
String pythonPath = pythonApiDirectory.getAbsolutePath();
if (processConfig.isDebugController() && "Controller".equals(componentId)) {
commands.add("-m");
commands.add("debugpy");
@ -167,6 +170,8 @@ public class PythonProcess {
commands.add(processConfig.getDebugHost() + ":" + processConfig.getDebugPort());
commands.add("--log-to");
commands.add(processConfig.getDebugLogsDirectory().getAbsolutePath());
pythonPath = pythonPath + File.pathSeparator + virtualEnvHome.getAbsolutePath();
}
commands.add(controllerPyFile.getAbsolutePath());
@ -175,7 +180,7 @@ public class PythonProcess {
processBuilder.environment().put("JAVA_PORT", String.valueOf(listeningPort));
processBuilder.environment().put("LOGS_DIR", pythonLogsDirectory.getAbsolutePath());
processBuilder.environment().put("ENV_HOME", virtualEnvHome.getAbsolutePath());
processBuilder.environment().put("PYTHONPATH", pythonApiDirectory.getAbsolutePath());
processBuilder.environment().put("PYTHONPATH", pythonPath);
processBuilder.environment().put("PYTHON_CMD", pythonCommandFile.getAbsolutePath());
processBuilder.environment().put("AUTH_TOKEN", authToken);
processBuilder.inheritIO();
@ -231,8 +236,8 @@ public class PythonProcess {
final String pythonCommand = processConfig.getPythonCommand();
final ProcessBuilder processBuilder = new ProcessBuilder(pythonCommand, "-m", "pip", "install", "--upgrade", "debugpy", "--target",
processConfig.getPythonWorkingDirectory().getAbsolutePath());
processBuilder.directory(virtualEnvHome.getParentFile());
virtualEnvHome.getAbsolutePath());
processBuilder.directory(virtualEnvHome);
final String command = String.join(" ", processBuilder.command());
logger.debug("Installing DebugPy to Virtual Env {} using command {}", virtualEnvHome, command);

View File

@ -17,8 +17,8 @@
package org.apache.nifi.python.processor;
import org.apache.nifi.annotation.behavior.DefaultRunDuration;
import org.apache.nifi.annotation.behavior.SupportsBatching;
import org.apache.nifi.annotation.behavior.InputRequirement;
import org.apache.nifi.annotation.behavior.InputRequirement.Requirement;
import org.apache.nifi.annotation.lifecycle.OnScheduled;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
@ -30,7 +30,7 @@ import py4j.Py4JNetworkException;
import java.util.Map;
import java.util.Optional;
@SupportsBatching(defaultDuration = DefaultRunDuration.TWENTY_FIVE_MILLIS)
@InputRequirement(Requirement.INPUT_REQUIRED)
public class FlowFileTransformProxy extends PythonProcessorProxy {
private final PythonProcessorBridge bridge;
@ -60,7 +60,7 @@ public class FlowFileTransformProxy extends PythonProcessorProxy {
return;
}
FlowFile transformed = session.create(original);
FlowFile transformed = session.clone(original);
final FlowFileTransformResult result;
try (final StandardInputFlowFile inputFlowFile = new StandardInputFlowFile(session, original)) {

View File

@ -17,6 +17,9 @@
package org.apache.nifi.python.processor;
import org.apache.nifi.annotation.behavior.DefaultRunDuration;
import org.apache.nifi.annotation.behavior.SupportsBatching;
import org.apache.nifi.annotation.behavior.SupportsSensitiveDynamicProperties;
import org.apache.nifi.annotation.lifecycle.OnScheduled;
import org.apache.nifi.annotation.lifecycle.OnStopped;
import org.apache.nifi.components.AsyncLoadedProcessor;
@ -36,6 +39,8 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
@SupportsBatching(defaultDuration = DefaultRunDuration.TWENTY_FIVE_MILLIS)
@SupportsSensitiveDynamicProperties
public abstract class PythonProcessorProxy extends AbstractProcessor implements AsyncLoadedProcessor {
private final PythonProcessorBridge bridge;
private volatile Set<Relationship> cachedRelationships = null;
@ -94,8 +99,8 @@ public abstract class PythonProcessorProxy extends AbstractProcessor implements
@Override
protected Collection<ValidationResult> customValidate(final ValidationContext validationContext) {
final Optional<PythonProcessorAdapter> optionalAdapter = bridge.getProcessorAdapter();
if (optionalAdapter.isEmpty()) {
final LoadState loadState = bridge.getLoadState();
if (loadState == LoadState.LOADING_PROCESSOR_CODE || loadState == LoadState.DOWNLOADING_DEPENDENCIES) {
return List.of(new ValidationResult.Builder()
.subject("Processor")
.explanation("Processor has not yet completed initialization")
@ -105,6 +110,16 @@ public abstract class PythonProcessorProxy extends AbstractProcessor implements
try {
reload();
final Optional<PythonProcessorAdapter> optionalAdapter = bridge.getProcessorAdapter();
if (optionalAdapter.isEmpty()) {
return List.of(new ValidationResult.Builder()
.subject("Processor")
.explanation("Processor has not yet completed initialization")
.valid(false)
.build());
}
return optionalAdapter.get().customValidate(validationContext);
} catch (final Exception e) {
getLogger().warn("Failed to perform validation for Python Processor {}; assuming invalid", this, e);
@ -166,11 +181,6 @@ public abstract class PythonProcessorProxy extends AbstractProcessor implements
this.cachedDynamicDescriptors = dynamicDescriptors;
}
@OnStopped
public void destroyCachedElements() {
this.cachedRelationships = null;
this.cachedDynamicDescriptors = null;
}
@Override
public Set<Relationship> getRelationships() {
@ -224,11 +234,20 @@ public abstract class PythonProcessorProxy extends AbstractProcessor implements
getLogger().info("Successfully reloaded Processor");
}
cachedPropertyDescriptors = null;
cachedRelationships = null;
supportsDynamicProperties = bridge.getProcessorAdapter()
.orElseThrow(() -> new IllegalStateException("Processor has not finished initializing"))
.isDynamicPropertySupported();
}
@Override
public void onPropertyModified(final PropertyDescriptor descriptor, final String oldValue, final String newValue) {
cachedPropertyDescriptors = null;
cachedRelationships = null;
super.onPropertyModified(descriptor, oldValue, newValue);
}
protected Set<Relationship> getImplicitRelationships() {
return implicitRelationships;
}

View File

@ -18,8 +18,8 @@
package org.apache.nifi.python.processor;
import org.apache.nifi.NullSuppression;
import org.apache.nifi.annotation.behavior.DefaultRunDuration;
import org.apache.nifi.annotation.behavior.SupportsBatching;
import org.apache.nifi.annotation.behavior.InputRequirement;
import org.apache.nifi.annotation.behavior.InputRequirement.Requirement;
import org.apache.nifi.annotation.lifecycle.OnScheduled;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.flowfile.FlowFile;
@ -57,7 +57,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
@SupportsBatching(defaultDuration = DefaultRunDuration.TWENTY_FIVE_MILLIS)
@InputRequirement(Requirement.INPUT_REQUIRED)
public class RecordTransformProxy extends PythonProcessorProxy {
private final PythonProcessorBridge bridge;
private volatile RecordTransform transform;

View File

@ -78,4 +78,9 @@ public class StandardInputFlowFile implements InputFlowFile, Closeable {
public Map<String, String> getAttributes() {
return flowFile.getAttributes();
}
@Override
public String toString() {
return "FlowFile[id=" + getAttribute("uuid") + ", filename=" + getAttribute("filename") + ", size=" + getSize() + "]";
}
}

View File

@ -177,7 +177,6 @@ public class PythonControllerInteractionIT {
.orElseThrow(() -> new RuntimeException("Could not find ConvertCsvToExcel"));
assertEquals("0.0.1-SNAPSHOT", convertCsvToExcel.getProcessorVersion());
assertNull(convertCsvToExcel.getPyPiPackageName());
assertEquals(new File("target/python/extensions/ConvertCsvToExcel.py").getAbsolutePath(),
new File(convertCsvToExcel.getSourceLocation()).getAbsolutePath());
}

View File

@ -32,31 +32,31 @@ class Logger:
def trace(self, msg, *args):
if self.min_level < LogLevel.DEBUG:
return
self.java_logger.trace(msg, self.__to_java_array(args))
self.java_logger.trace(str(msg), self.__to_java_array(args))
def debug(self, msg, *args):
if self.min_level < LogLevel.DEBUG:
return
self.java_logger.debug(msg, self.__to_java_array(args))
self.java_logger.debug(str(msg), self.__to_java_array(args))
def info(self, msg, *args):
if self.min_level < LogLevel.DEBUG:
return
self.java_logger.info(msg, self.__to_java_array(args))
self.java_logger.info(str(msg), self.__to_java_array(args))
def warn(self, msg, *args):
if self.min_level < LogLevel.DEBUG:
return
self.java_logger.warn(msg, self.__to_java_array(args))
self.java_logger.warn(str(msg), self.__to_java_array(args))
def error(self, msg, *args):
if self.min_level < LogLevel.DEBUG:
return
self.java_logger.error(msg, self.__to_java_array(args))
self.java_logger.error(str(msg), self.__to_java_array(args))
def __to_java_array(self, *args):
arg_array = JvmHolder.gateway.new_array(JvmHolder.jvm.java.lang.Object, len(args))
for i, arg in enumerate(args):
arg_array[i] = arg
arg_array[i] = None if arg is None else str(arg)
return arg_array

View File

@ -27,27 +27,27 @@ class ExpressionLanguageScope(Enum):
class StandardValidators:
__standard_validators__ = JvmHolder.jvm.org.apache.nifi.processor.util.StandardValidators
_standard_validators = JvmHolder.jvm.org.apache.nifi.processor.util.StandardValidators
ALWAYS_VALID = JvmHolder.jvm.org.apache.nifi.components.Validator.VALID
NON_EMPTY_VALIDATOR = __standard_validators__.NON_EMPTY_VALIDATOR
INTEGER_VALIDATOR = __standard_validators__.INTEGER_VALIDATOR
POSITIVE_INTEGER_VALIDATOR = __standard_validators__.POSITIVE_INTEGER_VALIDATOR
POSITIVE_LONG_VALIDATOR = __standard_validators__.POSITIVE_LONG_VALIDATOR
NON_NEGATIVE_INTEGER_VALIDATOR = __standard_validators__.NON_NEGATIVE_INTEGER_VALIDATOR
NUMBER_VALIDATOR = __standard_validators__.NUMBER_VALIDATOR
LONG_VALIDATOR = __standard_validators__.LONG_VALIDATOR
PORT_VALIDATOR = __standard_validators__.PORT_VALIDATOR
NON_EMPTY_EL_VALIDATOR = __standard_validators__.NON_EMPTY_EL_VALIDATOR
HOSTNAME_PORT_LIST_VALIDATOR = __standard_validators__.HOSTNAME_PORT_LIST_VALIDATOR
BOOLEAN_VALIDATOR = __standard_validators__.BOOLEAN_VALIDATOR
URL_VALIDATOR = __standard_validators__.URL_VALIDATOR
URI_VALIDATOR = __standard_validators__.URI_VALIDATOR
REGULAR_EXPRESSION_VALIDATOR = __standard_validators__.REGULAR_EXPRESSION_VALIDATOR
REGULAR_EXPRESSION_WITH_EL_VALIDATOR = __standard_validators__.REGULAR_EXPRESSION_WITH_EL_VALIDATOR
TIME_PERIOD_VALIDATOR = __standard_validators__.TIME_PERIOD_VALIDATOR
DATA_SIZE_VALIDATOR = __standard_validators__.DATA_SIZE_VALIDATOR
FILE_EXISTS_VALIDATOR = __standard_validators__.FILE_EXISTS_VALIDATOR
NON_EMPTY_VALIDATOR = _standard_validators.NON_EMPTY_VALIDATOR
INTEGER_VALIDATOR = _standard_validators.INTEGER_VALIDATOR
POSITIVE_INTEGER_VALIDATOR = _standard_validators.POSITIVE_INTEGER_VALIDATOR
POSITIVE_LONG_VALIDATOR = _standard_validators.POSITIVE_LONG_VALIDATOR
NON_NEGATIVE_INTEGER_VALIDATOR = _standard_validators.NON_NEGATIVE_INTEGER_VALIDATOR
NUMBER_VALIDATOR = _standard_validators.NUMBER_VALIDATOR
LONG_VALIDATOR = _standard_validators.LONG_VALIDATOR
PORT_VALIDATOR = _standard_validators.PORT_VALIDATOR
NON_EMPTY_EL_VALIDATOR = _standard_validators.NON_EMPTY_EL_VALIDATOR
HOSTNAME_PORT_LIST_VALIDATOR = _standard_validators.HOSTNAME_PORT_LIST_VALIDATOR
BOOLEAN_VALIDATOR = _standard_validators.BOOLEAN_VALIDATOR
URL_VALIDATOR = _standard_validators.URL_VALIDATOR
URI_VALIDATOR = _standard_validators.URI_VALIDATOR
REGULAR_EXPRESSION_VALIDATOR = _standard_validators.REGULAR_EXPRESSION_VALIDATOR
REGULAR_EXPRESSION_WITH_EL_VALIDATOR = _standard_validators.REGULAR_EXPRESSION_WITH_EL_VALIDATOR
TIME_PERIOD_VALIDATOR = _standard_validators.TIME_PERIOD_VALIDATOR
DATA_SIZE_VALIDATOR = _standard_validators.DATA_SIZE_VALIDATOR
FILE_EXISTS_VALIDATOR = _standard_validators.FILE_EXISTS_VALIDATOR
@ -332,7 +332,7 @@ class ProcessContext:
def getProperty(self, descriptor):
property_name = descriptor if isinstance(descriptor, str) else descriptor.name
return self.property_values[property_name]
return self.property_values.get(property_name)
def getProperties(self):
return self.descriptor_value_map
@ -389,7 +389,7 @@ class PythonPropertyValue:
def asBoolean(self):
if self.value is None:
return None
return bool(self.value)
return self.value.lower() == 'true'
def asFloat(self):
if self.value is None:

View File

@ -35,11 +35,6 @@ public interface PythonProcessorDetails {
*/
String getSourceLocation();
/**
* @return the name of the Python Package Index (PyPi) package, or <code>null</code> if it is not available
*/
String getPyPiPackageName();
/**
* @return the Processor's capability description
*/

View File

@ -13,12 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import ExtensionManager
import os
from concurrent.futures import ThreadPoolExecutor
from py4j.java_gateway import JavaGateway, CallbackServerParameters, GatewayParameters
import PythonProcessorAdapter
import ExtensionManager
# We do not use ThreadPoolExecutor directly, but the import must be kept. Python 3.9 introduced a bug that causes Exceptions to be raised
# incorrectly in multi-threaded applications (https://bugs.python.org/issue42647), and importing ThreadPoolExecutor works around it.
# Unfortunately, IntelliJ often likes to clean up the unused import, so we assign a bogus variable below just to keep
# a reference to ThreadPoolExecutor and prevent the IDE from removing the import.
threadpool_attrs = dir(ThreadPoolExecutor)
# Initialize logging
logger = logging.getLogger("org.apache.nifi.py4j.Controller")
@ -104,6 +113,19 @@ if __name__ == "__main__":
python_port = gateway.get_callback_server().get_listening_port()
logger.info("Listening for requests from Java side using Python Port {}, communicating with Java on port {}".format(python_port, java_port) )
# Initialize the JvmHolder class with the gateway jvm.
# This must be done before executing the module to ensure that the nifiapi module
# is able to access the JvmHolder.jvm variable. This enables the nifiapi.properties.StandardValidators, etc. to be used
# However, we have to delay the import until this point, rather than adding it to the top of the ExtensionManager class
# because we need to ensure that we've fetched the appropriate dependencies for the pyenv environment for the extension point.
from nifiapi.__jvm__ import JvmHolder
JvmHolder.jvm = gateway.jvm
JvmHolder.gateway = gateway
# We need to import PythonProcessorAdapter but cannot import it at the top of the class because we must first initialize the Gateway,
# since there are statically defined objects in the file that contains PythonProcessorAdapter, and those statically defined objects require the Gateway.
import PythonProcessorAdapter
# Notify the Java side of the port that Python is listening on
gateway.java_gateway_server.resetCallbackClient(
gateway.java_gateway_server.getCallbackClient().getAddress(),

View File

@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import ast
import importlib
import sys
import importlib.util # Note requires Python 3.4+
import inspect
import logging
import subprocess
import ast
import os
import pkgutil
import subprocess
import sys
from pathlib import Path
logger = logging.getLogger("org.apache.nifi.py4j.ExtensionManager")
@ -44,7 +44,13 @@ class ExtensionDetails:
class Java:
implements = ['org.apache.nifi.python.PythonProcessorDetails']
def __init__(self, gateway, type, interfaces, version='Unknown', dependencies=None, source_location=None, package_name=None, description=None, tags=None):
def __init__(self, gateway, type, interfaces,
version='Unknown',
dependencies=None,
source_location=None,
description=None,
tags=None):
self.gateway = gateway
self.type = type
@ -60,7 +66,6 @@ class ExtensionDetails:
self.version = version
self.dependencies = dependencies
self.source_location = source_location
self.package_name = package_name
self.description = description
self.tags = tags
@ -73,9 +78,6 @@ class ExtensionDetails:
def getSourceLocation(self):
return self.source_location
def getPyPiPackageName(self):
return self.package_name
def getDependencies(self):
list = self.gateway.jvm.java.util.ArrayList()
for dep in self.dependencies:
@ -180,7 +182,8 @@ class ExtensionManager:
# Delete the file that tells us that the dependencies have been downloaded. We do this only when reloading a processor
# because we want to ensure that we download any new dependencies
completion_marker_file = self.__get_download_complete_marker_file(work_dir, processor_type, version)
details = self.processor_details[id]
completion_marker_file = self.__get_download_complete_marker_file(work_dir, details)
if os.path.exists(completion_marker_file):
os.remove(completion_marker_file)
@ -188,7 +191,6 @@ class ExtensionManager:
self.__gather_extension_details(module_file, work_dir)
# Reload the processor class itself
details = self.processor_details[id]
processor_class = self.__load_extension_module(module_file, details.local_dependencies)
# Update our cache so that when the processor is created again, the new class will be used
@ -232,17 +234,21 @@ class ExtensionManager:
def __discover_extensions_from_paths(self, paths, work_dir, require_nifi_prefix):
for finder, name, ispkg in pkgutil.iter_modules(paths):
if not require_nifi_prefix or name.startswith('nifi_'):
module_file = '<Unknown Module File>'
try:
module = finder.find_module(name)
module_file = module.path
logger.info('Discovered extension %s' % module_file)
if paths is None:
paths = []
self.__gather_extension_details(module_file, work_dir)
except Exception:
logger.error("Failed to load Python extensions from module file {0}. This module will be ignored.".format(module_file), exc_info=True)
for path in paths:
for finder, name, ispkg in pkgutil.iter_modules([path]):
if not require_nifi_prefix or name.startswith('nifi_'):
module_file = '<Unknown Module File>'
try:
module = finder.find_module(name)
module_file = module.path
logger.info('Discovered extension %s' % module_file)
self.__gather_extension_details(module_file, work_dir)
except Exception:
logger.error("Failed to load Python extensions from module file {0}. This module will be ignored.".format(module_file), exc_info=True)
def __gather_extension_details(self, module_file, work_dir, local_dependencies=None):
@ -280,7 +286,7 @@ class ExtensionManager:
classes_and_details = self.__get_processor_classes_and_details(module_file)
for classname, details in classes_and_details.items():
id = ExtensionId(classname, details.version)
logger.info("Found local dependencies {0} for {1}".format(local_dependencies, classname))
logger.info(f"For {classname} found local dependencies {local_dependencies}")
details.local_dependencies = local_dependencies
@ -291,8 +297,9 @@ class ExtensionManager:
self.module_files_by_extension_type[id] = module_file
def __get_download_complete_marker_file(self, work_dir, extension_type, version):
return os.path.join(work_dir, 'extensions', extension_type, version, 'dependency-download.complete')
def __get_download_complete_marker_file(self, work_dir, processor_details):
version = processor_details.version
return os.path.join(work_dir, 'extensions', processor_details.type, version, 'dependency-download.complete')
def __get_dependencies_for_extension_type(self, extension_type, version):
@ -462,9 +469,8 @@ class ExtensionManager:
def import_external_dependencies(self, processor_details, work_dir):
class_name = processor_details.getProcessorType()
extension_version = processor_details.getProcessorVersion()
completion_marker_file = self.__get_download_complete_marker_file(work_dir, class_name, extension_version)
completion_marker_file = self.__get_download_complete_marker_file(work_dir, processor_details)
target_dir = os.path.dirname(completion_marker_file)
if not os.path.exists(target_dir):
@ -474,6 +480,21 @@ class ExtensionManager:
logger.info("All dependencies have already been imported for {0}".format(class_name))
return True
python_cmd = os.getenv("PYTHON_CMD")
if processor_details.source_location is not None:
package_dir = os.path.dirname(processor_details.source_location)
requirements_file = os.path.join(package_dir, 'requirements.txt')
if os.path.exists(requirements_file):
args = [python_cmd, '-m', 'pip', 'install', '--target', target_dir, '-r', requirements_file]
logger.info(f"Importing dependencies from requirements file for package {package_dir} to {target_dir} using command {args}")
result = subprocess.run(args)
if result.returncode == 0:
logger.info(f"Successfully imported requirements for package {package_dir} to {target_dir}")
else:
raise RuntimeError(f"Failed to import requirements for package {package_dir} from requirements.txt file: process exited with status code {result}")
dependencies = processor_details.getDependencies()
if len(dependencies) > 0:
python_cmd = os.getenv("PYTHON_CMD")
@ -498,34 +519,49 @@ class ExtensionManager:
def __load_extension_module(self, file, local_dependencies):
# If there are any local dependencies (i.e., other python files in the same directory), load those modules first
if local_dependencies is not None:
for local_dependency in local_dependencies:
if local_dependency == file:
continue
if local_dependencies is not None and len(local_dependencies) > 0:
to_load = [dep for dep in local_dependencies]
if file in to_load:
to_load.remove(file)
logger.debug(f"Loading local dependency {local_dependency} before loading {file}")
self.__load_extension_module(local_dependency, None)
# There is almost certainly a better way to do this. But we need to load all modules that are 'local dependencies'. I.e., all
# modules in the same directory/package. But Python does not appear to give us a simple way to do this. We could have a situation in which
# we have:
# Module A depends on B
# Module C depends on B
# Module B has no dependencies
# But we don't know the order of the dependencies so if we attempt to import Module A or C first, we get an ImportError because Module B hasn't
# been imported. To address this, we create a queue of dependencies. If we attempt to import one and it fails, we insert it at the front of the queue
# so that it will be tried again after trying all dependencies. After we attempt to load a dependency 10 times, we give up and re-throw the error.
attempts = {}
for dep in to_load:
attempts[dep] = 0
while len(to_load) > 0:
local_dependency = to_load.pop()
try:
logger.debug(f"Loading local dependency {local_dependency} before loading {file}")
self.__load_extension_module(local_dependency, None)
except:
previous_attempts = attempts[local_dependency]
if previous_attempts >= 10:
raise
attempts[local_dependency] = previous_attempts + 1
logger.debug(f"Failed to load local dependency {local_dependency}. Will try again after all have been attempted", exc_info=True)
to_load.insert(0, local_dependency)
# Determine the module name
moduleName = Path(file).name.split('.py')[0]
# Create the module specification
moduleSpec = importlib.util.spec_from_file_location(moduleName, file)
logger.debug('Module Spec: %s' % moduleSpec)
logger.debug(f"Module Spec: {moduleSpec}")
# Create the module from the specification
module = importlib.util.module_from_spec(moduleSpec)
logger.debug('Module: %s' % module)
# Initialize the JvmHolder class with the gateway jvm.
# This must be done before executing the module to ensure that the nifiapi module
# is able to access the JvmHolder.jvm variable. This enables the nifiapi.properties.StandardValidators, etc. to be used
# However, we have to delay the import until this point, rather than adding it to the top of the ExtensionManager class
# because we need to ensure that we've fetched the appropriate dependencies for the pyenv environment for the extension point.
from nifiapi.__jvm__ import JvmHolder
JvmHolder.jvm = self.gateway.jvm
JvmHolder.gateway = self.gateway
logger.debug(f"Module: {module}")
# Load the module
sys.modules[moduleName] = module

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from nifiapi.properties import ProcessContext
# PythonProcessorAdapter is responsible for receiving method invocations from Java side and delegating to the appropriate
# method for a Processor. We use this adapter instead of calling directly into the Processor because it allows us to be more
@ -53,7 +55,7 @@ class PythonProcessorAdapter:
if not self.hasCustomValidate:
return None
return self.processor.customValidate(context)
return self.processor.customValidate(ProcessContext(context))
def getRelationships(self):
# If self.relationships is None, it means that the Processor has implemented the method, and we need
@ -86,11 +88,11 @@ class PythonProcessorAdapter:
def onScheduled(self, context):
if self.hasMethod(self.processor, 'onScheduled'):
self.processor.onScheduled(context)
self.processor.onScheduled(ProcessContext(context))
def onStopped(self, context):
if self.hasMethod(self.processor, 'onStopped'):
self.processor.onStopped(context)
self.processor.onStopped(ProcessContext(context))
def initialize(self, context):
self.processor.logger = context.getLogger()

View File

@ -0,0 +1,49 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nifi-python-extensions</artifactId>
<groupId>org.apache.nifi</groupId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<packaging>jar</packaging>
<artifactId>nifi-openai-module</artifactId>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
<configuration>
<includeEmptyDirs>true</includeEmptyDirs>
<resources>
<resource>
<directory>src/main/python</directory>
<includes>
<include>**/</include>
</includes>
</resource>
</resources>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,219 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import PropertyDescriptor, StandardValidators, PropertyDependency, ExpressionLanguageScope, TimeUnit
FLOWFILE_CONTENT = 'flowfile_content'
FLOWFILE_CONTENT_REFERENCE = '{' + FLOWFILE_CONTENT + '}'
# Regex to match { followed by any number of characters other than { or }, followed by }. But do not match if it starts with {{
VAR_NAME_REGEX = r'(?<!{)\{([^{]*?)\}'
class PromptChatGPT(FlowFileTransform):
class Java:
implements = ['org.apache.nifi.python.processor.FlowFileTransform']
class ProcessorDetails:
version = '2.0.0-SNAPSHOT'
description = "Submits a prompt to ChatGPT, writing the results either to a FlowFile attribute or to the contents of the FlowFile"
tags = ["text", "chatgpt", "gpt", "machine learning", "ML", "artificial intelligence", "ai", "document", "langchain"]
dependencies = ['langchain', 'openai', 'jsonpath-ng']
MODEL = PropertyDescriptor(
name="OpenAI Model Name",
description="The name of the OpenAI Model to use in order to answer the prompt",
default_value="gpt-3.5-turbo",
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True
)
PROMPT = PropertyDescriptor(
name="Prompt",
description="The prompt to issue to ChatGPT. This may use FlowFile attributes via Expression Language and may also reference the FlowFile content by using the literal " +
f"{FLOWFILE_CONTENT_REFERENCE} (including braces) in the prompt. If the FlowFile's content is JSON formatted, a reference may also include JSONPath Expressions "
"to reference specific fields in the FlowFile content, such as {$.page_content}",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
required=True
)
TEMPERATURE = PropertyDescriptor(
name="Temperature",
description="The Temperature parameter to submit to OpenAI. A lower value will result in more consistent answers while a higher value will result in a more creative answer. " +
"The value must be between 0 and 2, inclusive.",
validators=[StandardValidators._standard_validators.createNonNegativeFloatingPointValidator(2.0)],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
required=True,
default_value="1.0"
)
RESULT_ATTRIBUTE = PropertyDescriptor(
name="Result Attribute",
description="If specified, the result will be added to the attribute whose name is given. If not specified, the result will be written to the FlowFile's content",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False
)
API_KEY = PropertyDescriptor(
name="API Key",
description="The OpenAI API Key to use",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
sensitive=True
)
TIMEOUT = PropertyDescriptor(
name="Request Timeout",
description="The amount of time to wait before timing out the request",
validators=[StandardValidators.TIME_PERIOD_VALIDATOR],
default_value="60 secs",
required=True
)
MAX_TOKENS = PropertyDescriptor(
name="Max Tokens to Generate",
description="The maximum number of tokens that ChatGPT should generate",
validators=[StandardValidators.POSITIVE_INTEGER_VALIDATOR],
required=False
)
ORGANIZATION = PropertyDescriptor(
name="OpenAI Organization ID",
description="The OpenAI Organization ID",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False
)
API_BASE = PropertyDescriptor(
name="API Base URL Path",
description="The API Base URL to use for interacting with OpenAI. This should be populated only if using a proxy or an emulator.",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False
)
property_descriptors = [
MODEL,
PROMPT,
TEMPERATURE,
RESULT_ATTRIBUTE,
API_KEY,
TIMEOUT,
MAX_TOKENS,
ORGANIZATION,
API_BASE
]
def __init__(self, **kwargs):
pass
def getPropertyDescriptors(self):
return self.property_descriptors
def transform(self, context, flowFile):
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains.llm import LLMChain
prompt = context.getProperty(self.PROMPT).evaluateAttributeExpressions(flowFile).getValue()
# We want to allow referencing FlowFile content using JSONPath Expressions.
# To do that, we allow the same {variable} syntax as Langchain. But Langchain does not allow '$' characters
# to exist in the variable names. So we need to replace those variables in the prompt with new variables, such as
# jsonpath_var_0, jsonpath_var_1, etc. To do this, we will use a Regex to detect any variables that are referenced
# and if it starts with a $ we will replace it with jsonpath_var_<index> and we will keep a mapping from that name to
# the substituted variable name so that we can later determine what the JSONPath expression was.
variable_references = list(set(re.findall(VAR_NAME_REGEX, prompt)))
input_variables = []
jsonpath_to_var_mapping = {}
index = 0
for ref in variable_references:
if ref.startswith("$"):
var_name = "jsonpath_var_" + str(index)
index += 1
input_variables.append(var_name)
jsonpath_to_var_mapping[ref] = var_name
prompt = prompt.replace("{" + ref + "}", "{" + var_name + "}")
elif ref == FLOWFILE_CONTENT:
input_variables.append(ref)
else:
raise ValueError("Prompt contained an invalid variable reference: {" + ref + "}. Valid references are flowfile_content or any JSONPath expression.")
temperature = context.getProperty(self.TEMPERATURE).evaluateAttributeExpressions(flowFile).asFloat()
model_name = context.getProperty(self.MODEL).evaluateAttributeExpressions(flowFile).getValue()
api_key = context.getProperty(self.API_KEY).getValue()
timeout = context.getProperty(self.TIMEOUT).asTimePeriod(TimeUnit.SECONDS)
max_tokens = context.getProperty(self.MAX_TOKENS).asInteger()
organization = context.getProperty(self.ORGANIZATION).getValue()
api_base = context.getProperty(self.API_BASE).getValue()
# Build out our LLMChain
llm = ChatOpenAI(model_name=model_name, temperature=temperature, openai_api_key=api_key, request_timeout=timeout, max_retries=0,
max_tokens=max_tokens, openai_organization=organization, openai_api_base=api_base)
prompt_template = PromptTemplate(
template=prompt,
input_variables=input_variables
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt_template
)
# Substitute in any JSON Path Expressions or references to {flowfile_content}.
llm_args = {}
json_content = None
for var_name in variable_references:
# If variable references {flowfile_content} substitute the content
if var_name == FLOWFILE_CONTENT:
llm_args[FLOWFILE_CONTENT] = flowFile.getContentsAsBytes().decode()
if var_name.startswith("$"):
# Load the FlowFile's contents into the json_content variable only once
if json_content is None:
json_content = json.loads(flowFile.getContentsAsBytes().decode())
# Import jsonpath_ng so that we can evaluate JSONPath against the FlowFile content.
from jsonpath_ng import parse
try:
jsonpath_expression = parse(var_name)
matches = jsonpath_expression.find(json_content)
variable_value = "\n".join([match.value for match in matches])
except:
self.logger.error("Invalid JSONPath reference in prompt: " + var_name)
raise
# Insert the resolved value into llm_args
resolved_var_name = jsonpath_to_var_mapping.get(var_name)
llm_args[resolved_var_name] = variable_value
self.logger.debug(f"Evaluating prompt\nPrompt: {prompt}\nArgs: #{llm_args}")
# Run the LLM Chain in order to prompt ChatGPT
results = llm_chain(llm_args)
# Create the output content or FlowFile attribute
text = results['text']
attribute_name = context.getProperty(self.RESULT_ATTRIBUTE).getValue()
if attribute_name is None:
output_content = text
output_attributes = None
else:
output_content = None
output_attributes = {attribute_name: text}
# Return the results
return FlowFileTransformResult("success", contents=output_content, attributes=output_attributes)

View File

@ -0,0 +1,97 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nifi-python-extensions</artifactId>
<groupId>org.apache.nifi</groupId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<packaging>pom</packaging>
<artifactId>nifi-python-extensions-bundle</artifactId>
<dependencies>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-openai-module</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-text-embeddings-module</artifactId>
<version>2.0.0-SNAPSHOT</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-dependency-plugin</artifactId>
<executions>
<execution>
<id>unpack-python-extensions</id>
<goals>
<goal>unpack-dependencies</goal>
</goals>
<phase>generate-test-resources</phase>
<configuration>
<excludeTransitive>true</excludeTransitive>
<excludes>META-INF, META-INF/**</excludes>
<outputDirectory>${project.build.directory}/python</outputDirectory>
<includeScope>runtime</includeScope>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<finalName>nifi-python-extensions-bundle</finalName>
<appendAssemblyId>false</appendAssemblyId>
<attach>true</attach>
</configuration>
<executions>
<execution>
<id>package</id>
<goals>
<goal>single</goal>
</goals>
<phase>generate-test-resources</phase>
<configuration>
<archiverConfig>
<defaultDirectoryMode>0775</defaultDirectoryMode>
<directoryMode>0775</directoryMode>
<fileMode>0664</fileMode>
</archiverConfig>
<descriptors>
<descriptor>src/main/assembly/dependencies.xml</descriptor>
</descriptors>
<tarLongFileMode>posix</tarLongFileMode>
<formats>
<format>zip</format>
</formats>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,27 @@
<?xml version="1.0"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<assembly>
<id>assembly</id>
<includeBaseDirectory>false</includeBaseDirectory>
<baseDirectory>./</baseDirectory>
<fileSets>
<fileSet>
<directory>${project.build.directory}/python</directory>
<outputDirectory>.</outputDirectory>
</fileSet>
</fileSets>
</assembly>

View File

@ -0,0 +1,50 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nifi-python-extensions</artifactId>
<groupId>org.apache.nifi</groupId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<packaging>jar</packaging>
<artifactId>nifi-text-embeddings-module</artifactId>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-resources-plugin</artifactId>
<configuration>
<includeEmptyDirs>true</includeEmptyDirs>
<resources>
<resource>
<directory>src/main/python</directory>
<includes>
<include>**/</include>
</includes>
</resource>
</resources>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,211 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from langchain.text_splitter import Language
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import PropertyDescriptor, StandardValidators, PropertyDependency, ExpressionLanguageScope
SPLIT_BY_CHARACTER = 'Split by Character'
SPLIT_CODE = 'Split Code'
RECURSIVELY_SPLIT_BY_CHARACTER = 'Recursively Split by Character'
TEXT_KEY = "text"
METADATA_KEY = "metadata"
class ChunkDocument(FlowFileTransform):
class Java:
implements = ['org.apache.nifi.python.processor.FlowFileTransform']
class ProcessorDetails:
version = '2.0.0-SNAPSHOT'
description = """Splits incoming documents into chunks that are appropriately sized for creating Text Embeddings. The input is expected to be in "json-lines" format, with
each line having a 'text' and a 'metadata' element. Each line will then be split into one or more lines in the output."""
tags = ["text", "split", "chunk", "langchain", "embeddings", "vector", "machine learning", "ML", "artificial intelligence", "ai", "document"]
dependencies = ['langchain']
CHUNK_STRATEGY = PropertyDescriptor(
name="Chunking Strategy",
description="Specifies which splitter should be used to split the text",
allowable_values=[RECURSIVELY_SPLIT_BY_CHARACTER, SPLIT_BY_CHARACTER, SPLIT_CODE],
required=True,
default_value=RECURSIVELY_SPLIT_BY_CHARACTER
)
SEPARATOR = PropertyDescriptor(
name="Separator",
description="Specifies the character sequence to use for splitting apart the text. If using a Chunking Strategy of Recursively Split by Character, " +
"it is a comma-separated list of character sequences. Meta-characters \\n, \\r and \\t are automatically un-escaped.",
required=True,
default_value="\\n\\n,\\n, ,",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
dependencies=[PropertyDependency(CHUNK_STRATEGY, SPLIT_BY_CHARACTER, RECURSIVELY_SPLIT_BY_CHARACTER)],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
SEPARATOR_FORMAT = PropertyDescriptor(
name="Separator Format",
description="Specifies how to interpret the value of the <Separator> property",
required=True,
default_value="Plain Text",
allowable_values=["Plain Text", "Regular Expression"],
dependencies=[PropertyDependency(CHUNK_STRATEGY, SPLIT_BY_CHARACTER, RECURSIVELY_SPLIT_BY_CHARACTER)]
)
CHUNK_SIZE = PropertyDescriptor(
name="Chunk Size",
description="The maximum size of a chunk that should be returned",
required=True,
default_value="4000",
validators=[StandardValidators.POSITIVE_INTEGER_VALIDATOR]
)
CHUNK_OVERLAP = PropertyDescriptor(
name="Chunk Overlap",
description="The number of characters that should be overlapped between each chunk of text",
required=True,
default_value="200",
validators=[StandardValidators.NON_NEGATIVE_INTEGER_VALIDATOR]
)
KEEP_SEPARATOR = PropertyDescriptor(
name="Keep Separator",
description="Whether or not to keep the text separator in each chunk of data",
required=True,
default_value="false",
allowable_values=["true", "false"],
dependencies=[PropertyDependency(CHUNK_STRATEGY, SPLIT_BY_CHARACTER, RECURSIVELY_SPLIT_BY_CHARACTER)]
)
STRIP_WHITESPACE = PropertyDescriptor(
name="Strip Whitespace",
description="Whether or not to strip the whitespace at the beginning and end of each chunk",
required=True,
default_value="true",
allowable_values=["true", "false"],
dependencies=[PropertyDependency(CHUNK_STRATEGY, SPLIT_BY_CHARACTER, RECURSIVELY_SPLIT_BY_CHARACTER)]
)
LANGUAGE = PropertyDescriptor(
name="Language",
description="The language to use for the Code's syntax",
required=True,
default_value="python",
allowable_values=[e.value for e in Language],
dependencies=[PropertyDependency(CHUNK_STRATEGY, SPLIT_CODE)]
)
property_descriptors = [CHUNK_STRATEGY,
SEPARATOR,
SEPARATOR_FORMAT,
CHUNK_SIZE,
CHUNK_OVERLAP,
KEEP_SEPARATOR,
STRIP_WHITESPACE]
def __init__(self, **kwargs):
pass
def getPropertyDescriptors(self):
return self.property_descriptors
def split_docs(self, context, flowfile, documents):
from langchain.text_splitter import CharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter
strategy = context.getProperty(self.CHUNK_STRATEGY).getValue()
if strategy == SPLIT_BY_CHARACTER:
text_splitter = CharacterTextSplitter(
separator = context.getProperty(self.SEPARATOR).evaluateAttributeExpressions(flowfile).getValue(),
keep_separator = context.getProperty(self.KEEP_SEPARATOR).asBoolean(),
is_separator_regex = context.getProperty(self.SEPARATOR_FORMAT).getValue() == 'Regular Expression',
chunk_size = context.getProperty(self.CHUNK_SIZE).asInteger(),
chunk_overlap = context.getProperty(self.CHUNK_OVERLAP).asInteger(),
length_function = len,
strip_whitespace = context.getProperty(self.STRIP_WHITESPACE).asBoolean()
)
elif strategy == SPLIT_CODE:
text_splitter = RecursiveCharacterTextSplitter.from_language(
language=context.getProperty(self.LANGUAGE).getValue(),
chunk_size = context.getProperty(self.CHUNK_SIZE).asInteger(),
chunk_overlap = context.getProperty(self.CHUNK_OVERLAP).asInteger()
)
else:
separator_text = context.getProperty(self.SEPARATOR).evaluateAttributeExpressions(flowfile).getValue()
splits = separator_text.split(",")
unescaped = []
for split in splits:
unescaped.append(split.replace("\\n", "\n").replace("\\r", "\r").replace("\\t", "\t"))
text_splitter = RecursiveCharacterTextSplitter(
separators = unescaped,
keep_separator = context.getProperty(self.KEEP_SEPARATOR).asBoolean(),
is_separator_regex = context.getProperty(self.SEPARATOR_FORMAT).getValue() == 'Regular Expression',
chunk_size = context.getProperty(self.CHUNK_SIZE).asInteger(),
chunk_overlap = context.getProperty(self.CHUNK_OVERLAP).asInteger(),
length_function = len,
strip_whitespace = context.getProperty(self.STRIP_WHITESPACE).asBoolean()
)
splits = text_splitter.split_documents(documents)
return splits
def to_json(self, docs) -> str:
json_docs = []
i = 0
for doc in docs:
doc.metadata['chunk_index'] = i
doc.metadata['chunk_count'] = len(docs)
i += 1
json_doc = json.dumps({
TEXT_KEY: doc.page_content,
METADATA_KEY: doc.metadata
})
json_docs.append(json_doc)
return "\n".join(json_docs)
def load_docs(self, flowfile):
from langchain.schema import Document
flowfile_contents = flowfile.getContentsAsBytes().decode()
docs = []
for line in flowfile_contents.split("\n"):
stripped = line.strip()
if stripped == "":
continue
json_element = json.loads(stripped)
page_content = json_element.get(TEXT_KEY)
if page_content is None:
continue
metadata = json_element.get(METADATA_KEY)
if metadata is None:
metadata = {}
doc = Document(page_content=page_content, metadata=metadata)
docs.append(doc)
return docs
def transform(self, context, flowfile):
documents = self.load_docs(flowfile)
split_docs = self.split_docs(context, flowfile, documents)
output_json = self.to_json(split_docs)
attributes = {"document.count": str(len(split_docs))}
return FlowFileTransformResult("success", contents=output_json, attributes=attributes)

View File

@ -0,0 +1,260 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import json
from typing import List
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import PropertyDescriptor, StandardValidators, PropertyDependency
PLAIN_TEXT = "Plain Text"
HTML = "HTML"
MARKDOWN = "Markdown"
PDF = "PDF"
EXCEL = "Microsoft Excel"
POWERPOINT = "Microsoft PowerPoint"
WORD = "Microsoft Word"
PARSING_STRATEGY_AUTO = "Automatic"
PARSING_STRATEGY_HIGH_RES = "High Resolution"
PARSING_STRATEGY_OCR_ONLY = "OCR Only"
PARSING_STRATEGY_FAST = "Fast"
SINGLE_DOCUMENT = "Single Document"
DOCUMENT_PER_ELEMENT = "Document Per Element"
TEXT_KEY = "text"
METADATA_KEY = "metadata"
class ParseDocument(FlowFileTransform):
class Java:
implements = ["org.apache.nifi.python.processor.FlowFileTransform"]
class ProcessorDetails:
version = "2.0.0-SNAPSHOT"
description = """Parses incoming unstructured text documents and performs optical character recognition (OCR) in order to extract text from PDF and image files.
The output is formatted as "json-lines" with two keys: 'text' and 'metadata'.
Note that use of this Processor may require significant storage space and RAM utilization due to third-party dependencies necessary for processing PDF and image files.
Also note that in order to process PDF or Images, Tesseract and Poppler must be installed on the system."""
tags = ["text", "embeddings", "vector", "machine learning", "ML", "artificial intelligence", "ai", "document", "langchain", "pdf", "html", "markdown", "word", "excel", "powerpoint"]
dependencies = ['langchain', 'unstructured', 'unstructured-inference', 'unstructured_pytesseract', 'numpy',
'opencv-python', 'pdf2image', 'pdfminer.six[image]', 'python-docx', 'openpyxl', 'python-pptx']
INPUT_FORMAT = PropertyDescriptor(
name="Input Format",
description="""The format of the input FlowFile. This dictates which TextLoader will be used to parse the input.
Note that in order to process images or extract tables from PDF files, you must have both 'poppler' and 'tesseract' installed on your system.""",
allowable_values=[PLAIN_TEXT, HTML, MARKDOWN, PDF, WORD, EXCEL, POWERPOINT],
required=True,
default_value=PLAIN_TEXT
)
PDF_PARSING_STRATEGY = PropertyDescriptor(
name="PDF Parsing Strategy",
display_name="Parsing Strategy",
description="Specifies the strategy to use when parsing a PDF",
allowable_values=[PARSING_STRATEGY_AUTO, PARSING_STRATEGY_HIGH_RES, PARSING_STRATEGY_OCR_ONLY, PARSING_STRATEGY_FAST],
required=True,
default_value=PARSING_STRATEGY_AUTO,
dependencies=[PropertyDependency(INPUT_FORMAT, PDF)]
)
PDF_MODEL_NAME = PropertyDescriptor(
name="PDF Parsing Model",
description="The model to use for parsing. Different models will have their own strengths and weaknesses.",
allowable_values=["yolox", "detectron2_onnx", "chipper"],
required=True,
default_value="yolox",
dependencies=[PropertyDependency(INPUT_FORMAT, PDF)]
)
ELEMENT_STRATEGY = PropertyDescriptor(
name="Element Strategy",
description="Specifies whether the input should be loaded as a single Document, or if each element in the input should be separated out into its own Document",
allowable_values=[SINGLE_DOCUMENT, DOCUMENT_PER_ELEMENT],
required=True,
default_value=DOCUMENT_PER_ELEMENT,
dependencies=[PropertyDependency(INPUT_FORMAT, HTML, MARKDOWN)]
)
INCLUDE_PAGE_BREAKS = PropertyDescriptor(
name="Include Page Breaks",
description="Specifies whether or not page breaks should be considered when creating Documents from the input",
allowable_values=["true", "false"],
required=True,
default_value="false",
dependencies=[PropertyDependency(INPUT_FORMAT, HTML, MARKDOWN),
PropertyDependency(ELEMENT_STRATEGY, DOCUMENT_PER_ELEMENT)]
)
PDF_INFER_TABLE_STRUCTURE = PropertyDescriptor(
name="Infer Table Structure",
description="If true, any table that is identified in the PDF will be parsed and translated into an HTML structure. The HTML of that table will then be added to the \
Document's metadata in a key named 'text_as_html'. Regardless of the value of this property, the textual contents of the table will be written to the contents \
without the structure.",
allowable_values=["true", "false"],
default_value="false",
required=True,
dependencies=[PropertyDependency(PDF_PARSING_STRATEGY, PARSING_STRATEGY_HIGH_RES)]
)
LANGUAGES = PropertyDescriptor(
name="Languages",
description="A comma-separated list of language codes that should be used when using OCR to determine the text.",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="Eng",
required=True,
dependencies=[PropertyDependency(INPUT_FORMAT, PDF)]
)
METADATA_FIELDS = PropertyDescriptor(
name="Metadata Fields",
description="A comma-separated list of FlowFile attributes that will be added to the Documents' Metadata",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="filename, uuid",
required=True
)
EXTRACT_METADATA = PropertyDescriptor(
name="Include Extracted Metadata",
description="Whether or not to include the metadata that is extracted from the input in each of the Documents",
allowable_values=["true", "false"],
default_value="true",
required=True
)
property_descriptors = [INPUT_FORMAT,
PDF_PARSING_STRATEGY,
PDF_MODEL_NAME,
ELEMENT_STRATEGY,
INCLUDE_PAGE_BREAKS,
PDF_INFER_TABLE_STRUCTURE,
LANGUAGES,
METADATA_FIELDS,
EXTRACT_METADATA]
def __init__(self, **kwargs):
pass
def getPropertyDescriptors(self):
return self.property_descriptors
def get_parsing_strategy(self, nifi_value:str, default_value: str) -> str:
if nifi_value == PARSING_STRATEGY_OCR_ONLY:
return "ocr_only"
if nifi_value == PARSING_STRATEGY_HIGH_RES:
return "hi_res"
if nifi_value == PARSING_STRATEGY_FAST:
return "fast"
if nifi_value == PARSING_STRATEGY_AUTO:
return "auto"
return default_value
def get_languages(self, nifi_value: str) -> List[str]:
return [
lang.strip()
for lang in nifi_value.split(",")
]
def create_docs(self, context, flowFile):
from langchain.schema import Document
metadata = {}
for attribute_name in context.getProperty(self.METADATA_FIELDS).getValue().split(","):
trimmed = attribute_name.strip()
value = flowFile.getAttribute(trimmed)
metadata[trimmed] = value
input_format = context.getProperty(self.INPUT_FORMAT).evaluateAttributeExpressions(flowFile).getValue()
if input_format == PLAIN_TEXT:
return [Document(page_content=flowFile.getContentsAsBytes().decode(), metadata=metadata)]
element_strategy = context.getProperty(self.ELEMENT_STRATEGY).getValue()
if element_strategy == SINGLE_DOCUMENT:
mode = "single"
else:
mode = "elements"
include_page_breaks = context.getProperty(self.INCLUDE_PAGE_BREAKS).asBoolean()
include_metadata = context.getProperty(self.EXTRACT_METADATA).asBoolean()
if input_format == HTML:
from langchain.document_loaders import UnstructuredHTMLLoader
loader = UnstructuredHTMLLoader(None, file=io.BytesIO(flowFile.getContentsAsBytes()), mode=mode, include_page_breaks=include_page_breaks, include_metadata=include_metadata)
elif input_format == PDF:
from langchain.document_loaders import UnstructuredPDFLoader
infer_table_structure = context.getProperty(self.PDF_INFER_TABLE_STRUCTURE).asBoolean()
strategy = self.get_parsing_strategy(context.getProperty(self.PDF_PARSING_STRATEGY).getValue(), "auto")
languages = self.get_languages(context.getProperty(self.LANGUAGES).getValue())
model_name = context.getProperty(self.PDF_MODEL_NAME).getValue()
loader = UnstructuredPDFLoader(None, file=io.BytesIO(flowFile.getContentsAsBytes()), mode=mode, infer_table_structure=infer_table_structure,
include_page_breaks=include_page_breaks, languages=languages, strategy=strategy, include_metadata=include_metadata, model_name=model_name)
elif input_format == MARKDOWN:
from langchain.document_loaders import UnstructuredMarkdownLoader
loader = UnstructuredMarkdownLoader(None, file=io.BytesIO(flowFile.getContentsAsBytes()), mode=mode, include_page_breaks=include_page_breaks, include_metadata=include_metadata)
elif input_format == WORD:
from langchain.document_loaders import UnstructuredWordDocumentLoader
loader = UnstructuredWordDocumentLoader(None, file=io.BytesIO(flowFile.getContentsAsBytes()), mode=mode, include_page_breaks=include_page_breaks, include_metadata=include_metadata)
elif input_format == EXCEL:
from langchain.document_loaders import UnstructuredExcelLoader
loader = UnstructuredExcelLoader(None, file=io.BytesIO(flowFile.getContentsAsBytes()), mode=mode, include_page_breaks=include_page_breaks, include_metadata=include_metadata)
elif input_format == POWERPOINT:
from langchain.document_loaders import UnstructuredPowerPointLoader
loader = UnstructuredPowerPointLoader(None, file=io.BytesIO(flowFile.getContentsAsBytes()), mode=mode, include_page_breaks=include_page_breaks, include_metadata=include_metadata)
else:
raise ValueError("Configured Input Format is invalid: " + input_format)
documents = loader.load()
if len(metadata) > 0:
for doc in documents:
if doc.metadata is None:
doc.metadata = metadata
else:
doc.metadata.update(metadata)
return documents
def to_json(self, docs) -> str:
json_docs = []
i = 0
for doc in docs:
doc.metadata['chunk_index'] = i
doc.metadata['chunk_count'] = len(docs)
i += 1
json_doc = json.dumps({
"text": doc.page_content,
"metadata": doc.metadata
})
json_docs.append(json_doc)
return "\n".join(json_docs)
def transform(self, context, flowFile):
documents = self.create_docs(context, flowFile)
output_json = self.to_json(documents)
return FlowFileTransformResult("success", contents=output_json, attributes={"mime.type": "application/json"})

View File

@ -0,0 +1,155 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nifiapi.properties import PropertyDescriptor, StandardValidators, PropertyDependency, ExpressionLanguageScope
# Connection Strategies
LOCAL_DISK = "Local Disk"
REMOTE_SERVER = "Remote Chroma Server"
# Authentication Strategies
TOKEN = "Token Authentication"
BASIC_AUTH = "Basic Authentication"
NONE = "None"
# Transport Protocols
HTTP = "http"
HTTPS = "https"
CONNECTION_STRATEGY = PropertyDescriptor(
name="Connection Strategy",
description="Specifies how to connect to the Chroma server",
allowable_values=[LOCAL_DISK, REMOTE_SERVER],
default_value=REMOTE_SERVER,
required=True
)
DIRECTORY = PropertyDescriptor(
name="Directory",
description="The Directory that Chroma should use to persist data",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
default_value="./chroma",
dependencies=[PropertyDependency(CONNECTION_STRATEGY, LOCAL_DISK)]
)
HOSTNAME = PropertyDescriptor(
name="Hostname",
description="The hostname to connect to in order to communicate with Chroma",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="localhost",
required=True,
dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)]
)
PORT = PropertyDescriptor(
name="Port",
description="The port that the Chroma server is listening on",
validators=[StandardValidators.PORT_VALIDATOR],
default_value="8000",
required=True,
dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)]
)
TRANSPORT_PROTOCOL = PropertyDescriptor(
name="Transport Protocol",
description="Specifies whether connections should be made over http or https",
allowable_values=[HTTP, HTTPS],
default_value=HTTPS,
required=True,
dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)]
)
AUTH_STRATEGY = PropertyDescriptor(
name="Authentication Strategy",
description="Specifies how to authenticate to Chroma server",
allowable_values=[TOKEN, BASIC_AUTH, NONE],
default_value=TOKEN,
required=True,
dependencies=[PropertyDependency(CONNECTION_STRATEGY, REMOTE_SERVER)]
)
AUTH_TOKEN = PropertyDescriptor(
name="Authentication Token",
description="The token to use for authenticating to Chroma server",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
sensitive=True,
dependencies=[PropertyDependency(AUTH_STRATEGY, TOKEN)]
)
USERNAME = PropertyDescriptor(
name="Username",
description="The username to use for authenticating to Chroma server",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
dependencies=[PropertyDependency(AUTH_STRATEGY, BASIC_AUTH)]
)
PASSWORD = PropertyDescriptor(
name="Password",
description="The password to use for authenticating to Chroma server",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
sensitive=True,
dependencies=[PropertyDependency(AUTH_STRATEGY, BASIC_AUTH)]
)
COLLECTION_NAME = PropertyDescriptor(
name="Collection Name",
description="The name of the Chroma Collection",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
default_value="nifi",
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
PROPERTIES = [
CONNECTION_STRATEGY,
DIRECTORY,
HOSTNAME,
PORT,
TRANSPORT_PROTOCOL,
AUTH_STRATEGY,
AUTH_TOKEN,
USERNAME,
PASSWORD,
COLLECTION_NAME
]
def create_client(context):
import chromadb
from chromadb import Settings
connection_strategy = context.getProperty(CONNECTION_STRATEGY).getValue()
if connection_strategy == LOCAL_DISK:
directory = context.getProperty(DIRECTORY).getValue()
return chromadb.PersistentClient(directory)
else:
hostname = context.getProperty(HOSTNAME).getValue()
port = context.getProperty(PORT).asInteger()
headers = {}
ssl = context.getProperty(TRANSPORT_PROTOCOL).getValue() == HTTPS
auth_strategy = context.getProperty(AUTH_STRATEGY).getValue()
if auth_strategy == TOKEN:
auth_provider = "chromadb.auth.token.TokenAuthClientProvider"
credentials = context.getProperty(AUTH_TOKEN).getValue()
elif auth_strategy == BASIC_AUTH:
auth_provider = "chromadb.auth.basic.BasicAuthClientProvider"
username = context.getProperty(USERNAME).getValue()
password = context.getProperty(PASSWORD).getValue()
credentials = username + ":" + password
else:
auth_provider = None
credentials = None
settings = Settings(
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=credentials
)
return chromadb.HttpClient(hostname, port, ssl, headers, settings)
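As a rough illustration, the hand-built equivalent of what create_client() returns for the Remote Chroma Server strategy with Token Authentication; the hostname, port, and token below are placeholders:

import chromadb
from chromadb import Settings

settings = Settings(
    chroma_client_auth_provider="chromadb.auth.token.TokenAuthClientProvider",
    chroma_client_auth_credentials="example-token"  # placeholder credential
)
# http vs https is controlled by the ssl flag (third argument)
client = chromadb.HttpClient("localhost", 8000, True, {}, settings)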

View File

@ -0,0 +1,147 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nifiapi.properties import PropertyDescriptor, StandardValidators, PropertyDependency, ExpressionLanguageScope
# Embedding Functions
ONNX_ALL_MINI_LM_L6_V2 = "ONNX all-MiniLM-L6-v2 Model"
HUGGING_FACE = "Hugging Face Model"
OPENAI = "OpenAI Model"
SENTENCE_TRANSFORMERS = "Sentence Transformers"
EMBEDDING_FUNCTION = PropertyDescriptor(
name="Embedding Function",
description="Specifies which embedding function should be used in order to create embeddings from incoming Documents",
allowable_values=[ONNX_ALL_MINI_LM_L6_V2, HUGGING_FACE, OPENAI, SENTENCE_TRANSFORMERS],
default_value=ONNX_ALL_MINI_LM_L6_V2,
required=True
)
HUGGING_FACE_MODEL_NAME = PropertyDescriptor(
name="HuggingFace Model Name",
description="The name of the HuggingFace model to use",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="sentence-transformers/all-MiniLM-L6-v2",
required=True,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, HUGGING_FACE)]
)
HUGGING_FACE_API_KEY = PropertyDescriptor(
name="HuggingFace API Key",
description="The API Key for interacting with HuggingFace",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
sensitive=True,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, HUGGING_FACE)]
)
OPENAI_API_KEY = PropertyDescriptor(
name="OpenAI API Key",
description="The API Key for interacting with OpenAI",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=True,
sensitive=True,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, OPENAI)]
)
OPENAI_MODEL_NAME = PropertyDescriptor(
name="OpenAI Model Name",
description="The name of the OpenAI model to use",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="text-embedding-ada-002",
required=True,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, OPENAI)]
)
OPENAI_ORGANIZATION = PropertyDescriptor(
name="OpenAI Organization ID",
description="The OpenAI Organization ID",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, OPENAI)]
)
OPENAI_API_BASE = PropertyDescriptor(
name="OpenAI API Base Path",
description="The API Base to use for interacting with OpenAI. This is used for interacting with different deployments, such as an Azure deployment.",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, OPENAI)]
)
OPENAI_API_TYPE = PropertyDescriptor(
name="OpenAI API Deployment Type",
description="The type of the OpenAI API Deployment. This is used for interacting with different deployments, such as an Azure deployment.",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, OPENAI)]
)
OPENAI_API_VERSION = PropertyDescriptor(
name="OpenAI API Version",
description="The OpenAI API Version. This is used for interacting with different deployments, such as an Azure deployment.",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, OPENAI)]
)
SENTENCE_TRANSFORMER_MODEL_NAME = PropertyDescriptor(
name="Sentence Transformer Model Name",
description="The name of the Sentence Transformer model to use",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="all-MiniLM-L6-v2",
required=True,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, SENTENCE_TRANSFORMERS)]
)
SENTENCE_TRANSFORMER_DEVICE = PropertyDescriptor(
name="Sentence Transformer Device Type",
description="The type of device to use for performing the embeddings using the Sentence Transformer, such as 'cpu', 'cuda', 'mps', 'cuda:0', etc. If not specified, a GPU will be used if "
+ "possible, otherwise a CPU.",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False,
dependencies=[PropertyDependency(EMBEDDING_FUNCTION, SENTENCE_TRANSFORMERS)]
)
PROPERTIES = [
EMBEDDING_FUNCTION,
HUGGING_FACE_MODEL_NAME,
HUGGING_FACE_API_KEY,
OPENAI_MODEL_NAME,
OPENAI_API_KEY,
OPENAI_ORGANIZATION,
OPENAI_API_BASE,
OPENAI_API_TYPE,
OPENAI_API_VERSION,
SENTENCE_TRANSFORMER_MODEL_NAME,
SENTENCE_TRANSFORMER_DEVICE
]
def create_embedding_function(context):
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, OpenAIEmbeddingFunction, HuggingFaceEmbeddingFunction, SentenceTransformerEmbeddingFunction
function_name = context.getProperty(EMBEDDING_FUNCTION).getValue()
if function_name == ONNX_ALL_MINI_LM_L6_V2:
return ONNXMiniLM_L6_V2()
if function_name == OPENAI:
api_key = context.getProperty(OPENAI_API_KEY).getValue()
model_name = context.getProperty(OPENAI_MODEL_NAME).getValue()
organization_id = context.getProperty(OPENAI_ORGANIZATION).getValue()
api_base = context.getProperty(OPENAI_API_BASE).getValue()
api_type = context.getProperty(OPENAI_API_TYPE).getValue()
api_version = context.getProperty(OPENAI_API_VERSION).getValue()
return OpenAIEmbeddingFunction(api_key=api_key, model_name=model_name, organization_id=organization_id, api_base=api_base, api_type=api_type, api_version=api_version)
if function_name == HUGGING_FACE:
api_key = context.getProperty(HUGGING_FACE_API_KEY).getValue()
model_name = context.getProperty(HUGGING_FACE_MODEL_NAME).getValue()
return HuggingFaceEmbeddingFunction(api_key=api_key, model_name=model_name)
model_name = context.getProperty(SENTENCE_TRANSFORMER_MODEL_NAME).getValue()
device = context.getProperty(SENTENCE_TRANSFORMER_DEVICE).getValue()
return SentenceTransformerEmbeddingFunction(model_name=model_name, device=device)
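For orientation, a minimal sketch of how the returned embedding function is used elsewhere in this bundle (PutChroma and QueryChroma call it directly on a list of strings); this uses the default ONNX model, so no API key is required, and the model is downloaded on first use:

from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2

embedding_function = ONNXMiniLM_L6_V2()
# Maps a list of texts to a list of embedding vectors (one vector per text)
vectors = embedding_function(["hello world", "apache nifi"])
print(len(vectors), len(vectors[0]))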

View File

@ -0,0 +1,125 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import PropertyDescriptor, StandardValidators, ExpressionLanguageScope
import ChromaUtils
import EmbeddingUtils
class PutChroma(FlowFileTransform):
class Java:
implements = ['org.apache.nifi.python.processor.FlowFileTransform']
class ProcessorDetails:
version = '2.0.0-SNAPSHOT'
description = """Publishes JSON data to a Chroma VectorDB. The Incoming data must be in single JSON per Line format, each with two keys: 'text' and 'metadata'.
The text must be a string, while metadata must be a map with strings for values. Any additional fields will be ignored. If the collection name specified
does not exist, the Processor will automatically create the collection."""
tags = ["chroma", "vector", "vectordb", "embeddings", "ai", "artificial intelligence", "ml", "machine learning", "text", "LLM"]
STORE_TEXT = PropertyDescriptor(
name="Store Document Text",
description="Specifies whether or not the text of the document should be stored in Chroma. If so, both the document's text and its embedding will be stored. If not, " +
"only the vector/embedding will be stored.",
allowable_values=["true", "false"],
required=True,
default_value="true"
)
DISTANCE_METHOD = PropertyDescriptor(
name="Distance Method",
description="If the specified collection does not exist, it will be created using this Distance Method. If the collection exists, this property will be ignored.",
allowable_values=["cosine", "l2", "ip"],
default_value="cosine",
required=True
)
DOC_ID_FIELD_NAME = PropertyDescriptor(
name="Document ID Field Name",
description="Specifies the name of the field in the 'metadata' element of each document where the document's ID can be found. " +
"If not specified, an ID will be generated based on the FlowFile's filename and a one-up number.",
required=False,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
client = None
embedding_function = None
def __init__(self, **kwargs):
self.property_descriptors = [prop for prop in ChromaUtils.PROPERTIES] + [prop for prop in EmbeddingUtils.PROPERTIES]
self.property_descriptors.append(self.STORE_TEXT)
self.property_descriptors.append(self.DISTANCE_METHOD)
self.property_descriptors.append(self.DOC_ID_FIELD_NAME)
def getPropertyDescriptors(self):
return self.property_descriptors
def onScheduled(self, context):
self.client = ChromaUtils.create_client(context)
self.embedding_function = EmbeddingUtils.create_embedding_function(context)
def transform(self, context, flowfile):
client = self.client
embedding_function = self.embedding_function
collection_name = context.getProperty(ChromaUtils.COLLECTION_NAME).evaluateAttributeExpressions(flowfile).getValue()
distance_method = context.getProperty(self.DISTANCE_METHOD).getValue()
id_field_name = context.getProperty(self.DOC_ID_FIELD_NAME).evaluateAttributeExpressions(flowfile).getValue()
collection = client.get_or_create_collection(
name=collection_name,
embedding_function=embedding_function,
metadata={"hnsw:space": distance_method})
json_lines = flowfile.getContentsAsBytes().decode()
i = 0
texts = []
metadatas = []
ids = []
for line in json_lines.split("\n"):
doc = json.loads(line)
text = doc.get('text')
metadata = doc.get('metadata')
texts.append(text)
# Remove any null values, or it will cause the embedding to fail
filtered_metadata = {}
for key, value in metadata.items():
if value is not None:
filtered_metadata[key] = value
metadatas.append(filtered_metadata)
doc_id = None
if id_field_name is not None:
doc_id = metadata.get(id_field_name)
if doc_id is None:
doc_id = flowfile.getAttribute("filename") + "-" + str(i)
ids.append(doc_id)
i += 1
embeddings = embedding_function(texts)
if not context.getProperty(self.STORE_TEXT).asBoolean():
texts = None
collection.upsert(ids, embeddings, metadatas, texts)
return FlowFileTransformResult(relationship = "success")
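A small, self-contained illustration of how PutChroma resolves each document's ID (mirroring the loop above); the 'doc_id' field name and filename are hypothetical:

import json

id_field_name = "doc_id"     # value of the 'Document ID Field Name' property (hypothetical)
filename = "example.json"    # the FlowFile's 'filename' attribute (hypothetical)

line = '{"text": "some chunk", "metadata": {"doc_id": "chunk-42", "source": "example"}}'
metadata = json.loads(line)["metadata"]

doc_id = metadata.get(id_field_name) if id_field_name else None
if doc_id is None:
    doc_id = filename + "-0"  # fallback: filename plus a one-up number
print(doc_id)                 # -> chunk-42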

View File

@ -0,0 +1,161 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from langchain.vectorstores import Pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import PropertyDescriptor, StandardValidators, ExpressionLanguageScope
import pinecone
import json
class PutPinecone(FlowFileTransform):
class Java:
implements = ['org.apache.nifi.python.processor.FlowFileTransform']
class ProcessorDetails:
version = '2.0.0-SNAPSHOT'
description = """Publishes JSON data to Pinecone. The Incoming data must be in single JSON per Line format, each with two keys: 'text' and 'metadata'.
The text must be a string, while metadata must be a map with strings for values. Any additional fields will be ignored."""
tags = ["pinecone", "vector", "vectordb", "vectorstore", "embeddings", "ai", "artificial intelligence", "ml", "machine learning", "text", "LLM"]
PINECONE_API_KEY = PropertyDescriptor(
name="Pinecone API Key",
description="The API Key to use in order to authentication with Pinecone",
sensitive=True,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR]
)
OPENAI_API_KEY = PropertyDescriptor(
name="OpenAI API Key",
description="The API Key for OpenAI in order to create embeddings",
sensitive=True,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR]
)
PINECONE_ENV = PropertyDescriptor(
name="Pinecone Environment",
description="The name of the Pinecone Environment. This can be found in the Pinecone console next to the API Key.",
sensitive=False,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR]
)
INDEX_NAME = PropertyDescriptor(
name="Index Name",
description="The name of the Pinecone index.",
sensitive=False,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
TEXT_KEY = PropertyDescriptor(
name="Text Key",
description="The key in the document that contains the text to create embeddings for.",
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="text",
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
NAMESPACE = PropertyDescriptor(
name="Namespace",
description="The name of the Pinecone Namespace to put the documents to.",
required=False,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
DOC_ID_FIELD_NAME = PropertyDescriptor(
name="Document ID Field Name",
description="Specifies the name of the field in the 'metadata' element of each document where the document's ID can be found. " +
"If not specified, an ID will be generated based on the FlowFile's filename and a one-up number.",
required=False,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
properties = [PINECONE_API_KEY,
OPENAI_API_KEY,
PINECONE_ENV,
INDEX_NAME,
TEXT_KEY,
NAMESPACE,
DOC_ID_FIELD_NAME]
embeddings = None
def __init__(self, **kwargs):
pass
def getPropertyDescriptors(self):
return self.properties
def onScheduled(self, context):
api_key = context.getProperty(self.PINECONE_API_KEY).getValue()
pinecone_env = context.getProperty(self.PINECONE_ENV).getValue()
# initialize pinecone
pinecone.init(
api_key=api_key,
environment=pinecone_env,
)
openai_api_key = context.getProperty(self.OPENAI_API_KEY).getValue()
self.embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
def transform(self, context, flowfile):
# Resolve the configured index, namespace, and Document ID field for this FlowFile
index_name = context.getProperty(self.INDEX_NAME).evaluateAttributeExpressions(flowfile).getValue()
namespace = context.getProperty(self.NAMESPACE).evaluateAttributeExpressions(flowfile).getValue()
id_field_name = context.getProperty(self.DOC_ID_FIELD_NAME).evaluateAttributeExpressions(flowfile).getValue()
index = pinecone.Index(index_name)
# Read the FlowFile content as "json-lines".
json_lines = flowfile.getContentsAsBytes().decode()
i = 1
texts = []
metadatas = []
ids = []
for line in json_lines.split("\n"):
try:
doc = json.loads(line)
except Exception as e:
raise ValueError(f"Could not parse line {i} as JSON") from e
text = doc.get('text')
metadata = doc.get('metadata')
texts.append(text)
# Remove any null values, or it will cause the embedding to fail
filtered_metadata = {}
for key, value in metadata.items():
if value is not None:
filtered_metadata[key] = value
metadatas.append(filtered_metadata)
doc_id = None
if id_field_name is not None:
doc_id = metadata.get(id_field_name)
if doc_id is None:
doc_id = flowfile.getAttribute("filename") + "-" + str(i)
ids.append(doc_id)
i += 1
text_key = context.getProperty(self.TEXT_KEY).evaluateAttributeExpressions(flowfile).getValue()
vectorstore = Pinecone(index, self.embeddings.embed_query, text_key)
vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids, namespace=namespace)
return FlowFileTransformResult(relationship = "success")

View File

@ -0,0 +1,159 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import PropertyDescriptor, StandardValidators, ExpressionLanguageScope, PropertyDependency
import ChromaUtils
import EmbeddingUtils
import QueryUtils
class QueryChroma(FlowFileTransform):
class Java:
implements = ['org.apache.nifi.python.processor.FlowFileTransform']
class ProcessorDetails:
version = '2.0.0-SNAPSHOT'
description = "Queries a Chroma Vector Database in order to gather a specified number of documents that are most closely related to the given query."
tags = ["chroma", "vector", "vectordb", "embeddings", "enrich", "enrichment", "ai", "artificial intelligence", "ml", "machine learning", "text", "LLM"]
QUERY = PropertyDescriptor(
name="Query",
description="The query to issue to the Chroma VectorDB. The query is always converted into embeddings using the configured embedding function, and the embedding is " +
"then sent to Chroma. The text itself is not sent to Chroma.",
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
NUMBER_OF_RESULTS = PropertyDescriptor(
name="Number of Results",
description="The number of results to return from Chroma",
required=True,
validators=[StandardValidators.POSITIVE_INTEGER_VALIDATOR],
default_value="10",
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
METADATA_FILTER = PropertyDescriptor(
name="Metadata Filter",
description="A JSON representation of a Metadata Filter that can be applied against the Chroma documents in order to narrow down the documents that can be returned. " +
"For example: { \"metadata_field\": \"some_value\" }",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
required=False
)
DOCUMENT_FILTER = PropertyDescriptor(
name="Document Filter",
description="A JSON representation of a Document Filter that can be applied against the Chroma documents' text in order to narrow down the documents that can be returned. " +
"For example: { \"$contains\": \"search_string\" }",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES,
required=False
)
client = None
embedding_function = None
include_ids = None
include_metadatas = None
include_documents = None
include_distances = None
include_embeddings = None
results_field = None
property_descriptors = [prop for prop in ChromaUtils.PROPERTIES] + [prop for prop in EmbeddingUtils.PROPERTIES] + [
QUERY,
NUMBER_OF_RESULTS,
QueryUtils.OUTPUT_STRATEGY,
QueryUtils.RESULTS_FIELD,
METADATA_FILTER,
DOCUMENT_FILTER,
QueryUtils.INCLUDE_IDS,
QueryUtils.INCLUDE_METADATAS,
QueryUtils.INCLUDE_DOCUMENTS,
QueryUtils.INCLUDE_DISTANCES,
QueryUtils.INCLUDE_EMBEDDINGS]
def __init__(self, **kwargs):
pass
def getPropertyDescriptors(self):
return self.property_descriptors
def onScheduled(self, context):
self.client = ChromaUtils.create_client(context)
self.embedding_function = EmbeddingUtils.create_embedding_function(context)
self.include_ids = context.getProperty(QueryUtils.INCLUDE_IDS).asBoolean()
self.include_metadatas = context.getProperty(QueryUtils.INCLUDE_METADATAS).asBoolean()
self.include_documents = context.getProperty(QueryUtils.INCLUDE_DOCUMENTS).asBoolean()
self.include_distances = context.getProperty(QueryUtils.INCLUDE_DISTANCES).asBoolean()
self.include_embeddings = context.getProperty(QueryUtils.INCLUDE_EMBEDDINGS).asBoolean()
self.results_field = context.getProperty(QueryUtils.RESULTS_FIELD).getValue()
self.query_utils = QueryUtils.QueryUtils(context)
def transform(self, context, flowfile):
client = self.client
embedding_function = self.embedding_function
collection_name = context.getProperty(ChromaUtils.COLLECTION_NAME).evaluateAttributeExpressions(flowfile).getValue()
collection = client.get_collection(
name=collection_name,
embedding_function=embedding_function)
query_text = context.getProperty(self.QUERY).evaluateAttributeExpressions(flowfile).getValue()
embeddings = embedding_function([query_text])
included_fields = []
if self.include_distances:
included_fields.append('distances')
if self.include_documents:
included_fields.append('documents')
if self.include_embeddings:
included_fields.append('embeddings')
if self.include_metadatas:
included_fields.append('metadatas')
where = None
where_clause = context.getProperty(self.METADATA_FILTER).evaluateAttributeExpressions(flowfile).getValue()
if where_clause is not None:
where = json.loads(where_clause)
where_document = None
where_document_clause = context.getProperty(self.DOCUMENT_FILTER).evaluateAttributeExpressions(flowfile).getValue()
if where_document_clause is not None:
where_document = json.loads(where_document_clause)
query_results = collection.query(
query_embeddings=embeddings,
n_results=context.getProperty(self.NUMBER_OF_RESULTS).evaluateAttributeExpressions(flowfile).asInteger(),
include=included_fields,
where_document=where_document,
where=where
)
ids = query_results['ids'][0]
distances = None if (not self.include_distances or query_results['distances'] is None) else query_results['distances'][0]
metadatas = None if (not self.include_metadatas or query_results['metadatas'] is None) else query_results['metadatas'][0]
documents = None if (not self.include_documents or query_results['documents'] is None) else query_results['documents'][0]
embeddings = None if (not self.include_embeddings or query_results['embeddings'] is None) else query_results['embeddings'][0]
(output_contents, mime_type) = self.query_utils.create_json(flowfile, documents, metadatas, embeddings, distances, ids)
# Return the results
attributes = {"mime.type": mime_type}
return FlowFileTransformResult(relationship = "success", contents=output_contents, attributes=attributes)

View File

@ -0,0 +1,165 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from langchain.vectorstores import Pinecone
from langchain.embeddings.openai import OpenAIEmbeddings
from nifiapi.flowfiletransform import FlowFileTransform, FlowFileTransformResult
from nifiapi.properties import PropertyDescriptor, StandardValidators, ExpressionLanguageScope
import QueryUtils
import pinecone
class QueryPinecone(FlowFileTransform):
class Java:
implements = ['org.apache.nifi.python.processor.FlowFileTransform']
class ProcessorDetails:
version = '2.0.0-SNAPSHOT'
description = "Queries Pinecone in order to gather a specified number of documents that are most closely related to the given query."
tags = ["pinecone", "vector", "vectordb", "vectorstore", "embeddings", "ai", "artificial intelligence", "ml", "machine learning", "text", "LLM"]
PINECONE_API_KEY = PropertyDescriptor(
name="Pinecone API Key",
description="The API Key to use in order to authentication with Pinecone",
sensitive=True,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR]
)
OPENAI_API_KEY = PropertyDescriptor(
name="OpenAI API Key",
description="The API Key for OpenAI in order to create embeddings",
sensitive=True,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR]
)
PINECONE_ENV = PropertyDescriptor(
name="Pinecone Environment",
description="The name of the Pinecone Environment. This can be found in the Pinecone console next to the API Key.",
sensitive=False,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR]
)
INDEX_NAME = PropertyDescriptor(
name="Index Name",
description="The name of the Pinecone index.",
sensitive=False,
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
QUERY = PropertyDescriptor(
name="Query",
description="The text of the query to send to Pinecone.",
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
NUMBER_OF_RESULTS = PropertyDescriptor(
name="Number of Results",
description="The number of results to return from Pinecone",
required=True,
validators=[StandardValidators.POSITIVE_INTEGER_VALIDATOR],
default_value="10",
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
TEXT_KEY = PropertyDescriptor(
name="Text Key",
description="The key in the document that contains the text to create embeddings for.",
required=True,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
default_value="text",
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
NAMESPACE = PropertyDescriptor(
name="Namespace",
description="The name of the Pinecone Namespace to put the documents to.",
required=False,
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
expression_language_scope=ExpressionLanguageScope.FLOWFILE_ATTRIBUTES
)
properties = [PINECONE_API_KEY,
OPENAI_API_KEY,
PINECONE_ENV,
INDEX_NAME,
QUERY,
NUMBER_OF_RESULTS,
NAMESPACE,
TEXT_KEY,
QueryUtils.OUTPUT_STRATEGY,
QueryUtils.RESULTS_FIELD,
QueryUtils.INCLUDE_METADATAS,
QueryUtils.INCLUDE_DISTANCES]
embeddings = None
query_utils = None
def __init__(self, **kwargs):
pass
def getPropertyDescriptors(self):
return self.properties
def onScheduled(self, context):
api_key = context.getProperty(self.PINECONE_API_KEY).getValue()
pinecone_env = context.getProperty(self.PINECONE_ENV).getValue()
# initialize pinecone
pinecone.init(
api_key=api_key,
environment=pinecone_env,
)
openai_api_key = context.getProperty(self.OPENAI_API_KEY).getValue()
self.embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
self.query_utils = QueryUtils.QueryUtils(context)
def transform(self, context, flowfile):
# Resolve the configured index, query, and namespace for this FlowFile
index_name = context.getProperty(self.INDEX_NAME).evaluateAttributeExpressions(flowfile).getValue()
query = context.getProperty(self.QUERY).evaluateAttributeExpressions(flowfile).getValue()
namespace = context.getProperty(self.NAMESPACE).evaluateAttributeExpressions(flowfile).getValue()
num_results = context.getProperty(self.NUMBER_OF_RESULTS).evaluateAttributeExpressions(flowfile).asInteger()
index = pinecone.Index(index_name)
text_key = context.getProperty(self.TEXT_KEY).evaluateAttributeExpressions(flowfile).getValue()
vectorstore = Pinecone(index, self.embeddings.embed_query, text_key, namespace=namespace)
results = vectorstore.similarity_search_with_score(query, num_results)
documents = []
for result in results:
documents.append(result[0].page_content)
if context.getProperty(QueryUtils.INCLUDE_METADATAS).asBoolean():
metadatas = []
for result in results:
metadatas.append(result[0].metadata)
else:
metadatas = None
if context.getProperty(QueryUtils.INCLUDE_DISTANCES).asBoolean():
distances = []
for result in results:
distances.append(result[1])
else:
distances = None
(output_contents, mime_type) = self.query_utils.create_json(flowfile, documents, metadatas, None, distances, None)
attributes = {"mime.type": mime_type}
return FlowFileTransformResult(relationship = "success", contents=output_contents, attributes=attributes)

View File

@ -0,0 +1,188 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from nifiapi.properties import PropertyDescriptor, StandardValidators, ExpressionLanguageScope, PropertyDependency
import json
ROW_ORIENTED = "Row-Oriented"
TEXT = "Text"
COLUMN_ORIENTED = "Column-Oriented"
OUTPUT_STRATEGY = PropertyDescriptor(
name="Output Strategy",
description="Specifies whether the output should contain only the text of the documents (each document separated by \\n\\n), or if it " +
"should be formatted as either single column-oriented JSON object, " +
"consisting of a keys 'ids', 'embeddings', 'documents', 'distances', and 'metadatas'; or if the results should be row-oriented, " +
"a JSON per line, each consisting of a single id, document, metadata, embedding, and distance.",
allowable_values=[ROW_ORIENTED, TEXT, COLUMN_ORIENTED],
default_value=ROW_ORIENTED,
required=True
)
RESULTS_FIELD = PropertyDescriptor(
name="Results Field",
description="If the input FlowFile is JSON Formatted, this represents the name of the field to insert the results. This allows the results to be inserted into " +
"an existing input in order to enrich it. If this property is unset, the results will be written to the FlowFile contents, overwriting any pre-existing content.",
validators=[StandardValidators.NON_EMPTY_VALIDATOR],
required=False
)
INCLUDE_IDS = PropertyDescriptor(
name="Include Document IDs",
description="Whether or not to include the Documents' IDs in the response",
allowable_values=["true", "false"],
default_value="true",
required=False,
dependencies=[PropertyDependency(OUTPUT_STRATEGY, ROW_ORIENTED, COLUMN_ORIENTED)]
)
INCLUDE_METADATAS = PropertyDescriptor(
name="Include Metadata",
description="Whether or not to include the Documents' Metadata in the response",
allowable_values=["true", "false"],
default_value="true",
required=False,
dependencies=[PropertyDependency(OUTPUT_STRATEGY, ROW_ORIENTED, COLUMN_ORIENTED)]
)
INCLUDE_DOCUMENTS = PropertyDescriptor(
name="Include Document",
description="Whether or not to include the Documents' Text in the response",
allowable_values=["true", "false"],
default_value="true",
required=False,
dependencies=[PropertyDependency(OUTPUT_STRATEGY, ROW_ORIENTED, COLUMN_ORIENTED)]
)
INCLUDE_DISTANCES = PropertyDescriptor(
name="Include Distances",
description="Whether or not to include the Documents' Distances (i.e., how far the Document was away from the query) in the response",
allowable_values=["true", "false"],
default_value="true",
required=False,
dependencies=[PropertyDependency(OUTPUT_STRATEGY, ROW_ORIENTED, COLUMN_ORIENTED)]
)
INCLUDE_EMBEDDINGS = PropertyDescriptor(
name="Include Embeddings",
description="Whether or not to include the Documents' Embeddings in the response",
allowable_values=["true", "false"],
default_value="false",
required=False,
dependencies=[PropertyDependency(OUTPUT_STRATEGY, ROW_ORIENTED, COLUMN_ORIENTED)]
)
class QueryUtils:
context = None
def __init__(self, context):
self.context = context
self.results_field = context.getProperty(RESULTS_FIELD).getValue()
self.output_strategy = context.getProperty(OUTPUT_STRATEGY).getValue()
ids_property = context.getProperty(INCLUDE_IDS)
self.include_ids = ids_property.asBoolean() if ids_property else False
embeddings_property = context.getProperty(INCLUDE_EMBEDDINGS)
self.include_embeddings = embeddings_property.asBoolean() if embeddings_property else False
self.include_distances = context.getProperty(INCLUDE_DISTANCES).asBoolean()
documents_property = context.getProperty(INCLUDE_DOCUMENTS)
self.include_documents = documents_property.asBoolean() if documents_property else True
self.include_metadatas = context.getProperty(INCLUDE_METADATAS).asBoolean()
def create_json(self, flowfile, documents, metadatas, embeddings, distances, ids) -> Tuple[str, str]:
if self.results_field is None:
input_json = None
else:
input_json = json.loads(flowfile.getContentsAsBytes().decode())
if self.output_strategy == TEXT:
# Delete any document that is None or an empty-string
documents = [doc for doc in documents if doc is not None and doc != ""]
# Join the documents with two newlines
text = "\n\n".join(documents)
# Create either JSON or text output, based on whether or not a results field was specified
if input_json is None:
mime_type = "text/plain"
output_contents = text
else:
input_json[self.results_field] = text
output_contents = json.dumps(input_json)
mime_type = "application/json"
elif self.output_strategy == COLUMN_ORIENTED:
doc = {}
if self.include_ids:
doc['ids'] = ids
if self.include_distances:
doc['distances'] = distances
if self.include_documents:
doc['documents'] = documents
if self.include_metadatas:
doc['metadatas'] = metadatas
if self.include_embeddings:
doc['embeddings'] = embeddings
# Create the JSON from the Document
if input_json is None:
output_contents = json.dumps(doc)
else:
input_json[self.results_field] = doc
output_contents = json.dumps(input_json)
mime_type = "application/json"
else:
# Build the Documents
docs = []
count = len(ids) if ids else len(documents)
for i in range(count):
id = None if ids is None else ids[i]
distance = None if distances is None else distances[i]
metadata = None if metadatas is None else metadatas[i]
document = None if documents is None else documents[i]
embedding = None if embeddings is None else embeddings[i]
# Create the document but do not include any key that we don't want to include in the output.
doc = {}
if self.include_ids:
doc['id'] = id
if self.include_distances:
doc['distance'] = distance
if self.include_documents:
doc['document'] = document
if self.include_metadatas:
doc['metadata'] = metadata
if self.include_embeddings:
doc['embedding'] = embedding
docs.append(doc)
# If input_json is None, we just create JSON based on the Documents.
# If input_json is populated, we insert the documents into the input JSON using the specified key.
if input_json is None:
jsons = []
for doc in docs:
jsons.append(json.dumps(doc))
output_contents = "\n".join(jsons)
else:
input_json[self.results_field] = docs
output_contents = json.dumps(input_json)
mime_type = "application/json"
return output_contents, mime_type
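To make the three Output Strategy shapes concrete, a hedged example of what create_json() emits for two hits with the default include flags and no Results Field set (all values illustrative):

# Text: documents joined by blank lines, mime.type text/plain
#   first matched document
#
#   second matched document
#
# Column-Oriented: one JSON object, mime.type application/json
#   {"ids": ["a", "b"], "distances": [0.12, 0.31],
#    "documents": ["first matched document", "second matched document"],
#    "metadatas": [{"source": "x"}, {"source": "y"}]}
#
# Row-Oriented: one JSON object per line, mime.type application/json
#   {"id": "a", "distance": 0.12, "document": "first matched document", "metadata": {"source": "x"}}
#   {"id": "b", "distance": 0.31, "document": "second matched document", "metadata": {"source": "y"}}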

View File

@ -0,0 +1,14 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,29 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Shared requirements
openai
# Chroma requirements
chromadb==0.4.14
onnxruntime
tokenizers
tqdm
requests
# Pinecone requirements
pinecone-client
tiktoken
langchain

View File

@ -0,0 +1,36 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nifi</artifactId>
<groupId>org.apache.nifi</groupId>
<version>2.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<packaging>pom</packaging>
<artifactId>nifi-python-extensions</artifactId>
<modules>
<module>nifi-text-embeddings-module</module>
<module>nifi-openai-module</module>
<module>nifi-python-extensions-bundle</module>
</modules>
</project>

View File

@ -42,6 +42,7 @@
<module>nifi-registry</module>
<module>nifi-toolkit</module>
<module>nifi-manifest</module>
<module>nifi-python-extensions</module>
<module>c2</module>
</modules>
<url>https://nifi.apache.org</url>