diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml
index c40c81ae98..a2ab512cd4 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml
@@ -65,6 +65,22 @@
software.amazon.awssdk
firehose
+        <dependency>
+            <groupId>software.amazon.awssdk</groupId>
+            <artifactId>polly</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>software.amazon.awssdk</groupId>
+            <artifactId>textract</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>software.amazon.awssdk</groupId>
+            <artifactId>transcribe</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>software.amazon.awssdk</groupId>
+            <artifactId>translate</artifactId>
+        </dependency>
software.amazon.kinesis
amazon-kinesis-client
@@ -130,6 +146,10 @@
com.github.ben-manes.caffeine
caffeine
+        <dependency>
+            <groupId>com.fasterxml.jackson.datatype</groupId>
+            <artifactId>jackson-datatype-jsr310</artifactId>
+        </dependency>
org.bouncycastle
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStarter.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStarter.java
similarity index 66%
rename from nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStarter.java
rename to nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStarter.java
index 27f0664370..f1c0ab1fb3 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStarter.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStarter.java
@@ -17,14 +17,6 @@
package org.apache.nifi.processors.aws.ml;
-import com.amazonaws.AmazonWebServiceClient;
-import com.amazonaws.AmazonWebServiceRequest;
-import com.amazonaws.AmazonWebServiceResult;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -33,12 +25,18 @@ import org.apache.commons.io.IOUtils;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
+import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
-import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor;
+import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
+import software.amazon.awssdk.awscore.AwsRequest;
+import software.amazon.awssdk.awscore.AwsResponse;
+import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
+import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
+import software.amazon.awssdk.core.SdkClient;
import java.io.IOException;
import java.io.InputStream;
@@ -46,10 +44,15 @@ import java.util.List;
import java.util.Set;
import static org.apache.nifi.flowfile.attributes.CoreAttributes.MIME_TYPE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
-public abstract class AwsMachineLearningJobStarter
- extends AbstractAWSCredentialsProviderProcessor {
+public abstract class AbstractAwsMachineLearningJobStarter<
+        Q extends AwsRequest,
+        B extends AwsRequest.Builder,
+        R extends AwsResponse,
+        T extends SdkClient,
+        U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
+        extends AbstractAwsSyncProcessor<T, U> {
public static final PropertyDescriptor JSON_PAYLOAD = new PropertyDescriptor.Builder()
.name("json-payload")
.displayName("JSON Payload")
@@ -62,18 +65,17 @@ public abstract class AwsMachineLearningJobStarter PROPERTIES = List.of(
MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE,
REGION,
@@ -84,10 +86,9 @@ public abstract class AwsMachineLearningJobStarter relationships = Set.of(REL_ORIGINAL,
- REL_SUCCESS,
- REL_FAILURE);
+    private static final Set<Relationship> relationships = Set.of(REL_ORIGINAL, REL_SUCCESS, REL_FAILURE);
@Override
     public Set<Relationship> getRelationships() {
@@ -105,14 +106,14 @@ public abstract class AwsMachineLearningJobStarter MAPPER.writeValue(out, response));
+ childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
return childFlowFile;
}
@@ -156,7 +152,7 @@ public abstract class AwsMachineLearningJobStarter getAwsRequestClass(ProcessContext context, FlowFile flowFile);
+    abstract protected Class<? extends B> getAwsRequestBuilderClass(ProcessContext context, FlowFile flowFile);
- abstract protected String getAwsTaskId(ProcessContext context, RESPONSE response, FlowFile flowFile);
+ abstract protected String getAwsTaskId(ProcessContext context, R response, FlowFile flowFile);
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStatusProcessor.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStatusProcessor.java
similarity index 69%
rename from nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStatusProcessor.java
rename to nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStatusProcessor.java
index 5a5f8f02e2..8cc555b46a 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStatusProcessor.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStatusProcessor.java
@@ -17,33 +17,30 @@
package org.apache.nifi.processors.aws.ml;
-import com.amazonaws.AmazonWebServiceClient;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.ResponseMetadata;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.http.SdkHttpMetadata;
-import com.amazonaws.regions.Region;
-import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
-import com.fasterxml.jackson.databind.module.SimpleModule;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.flowfile.FlowFile;
-import org.apache.nifi.processor.ProcessContext;
+import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.util.StandardValidators;
-import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor;
+import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
+import software.amazon.awssdk.awscore.AwsResponse;
+import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
+import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
+import software.amazon.awssdk.core.SdkClient;
import java.util.List;
import java.util.Set;
import static org.apache.nifi.expression.ExpressionLanguageScope.FLOWFILE_ATTRIBUTES;
-public abstract class AwsMachineLearningJobStatusProcessor
- extends AbstractAWSCredentialsProviderProcessor {
+public abstract class AbstractAwsMachineLearningJobStatusProcessor<
+        T extends SdkClient,
+        U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
+        extends AbstractAwsSyncProcessor<T, U> {
public static final String AWS_TASK_OUTPUT_LOCATION = "outputLocation";
public static final PropertyDescriptor MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE =
new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE)
@@ -81,13 +78,6 @@ public abstract class AwsMachineLearningJobStatusProcessor PROPERTIES = List.of(
TASK_ID,
@@ -99,18 +89,14 @@ public abstract class AwsMachineLearningJobStatusProcessor getRelationships() {
return relationships;
@@ -129,13 +115,7 @@ public abstract class AwsMachineLearningJobStatusProcessor MAPPER.writeValue(out, response));
+ protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final AwsResponse response) {
+ return session.write(flowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml
index 627b4f8b7f..c0748af598 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml
@@ -125,22 +125,6 @@
org.apache.nifi
nifi-schema-registry-service-api
-        <dependency>
-            <groupId>com.amazonaws</groupId>
-            <artifactId>aws-java-sdk-translate</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>com.amazonaws</groupId>
-            <artifactId>aws-java-sdk-polly</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>com.amazonaws</groupId>
-            <artifactId>aws-java-sdk-transcribe</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>com.amazonaws</groupId>
-            <artifactId>aws-java-sdk-textract</artifactId>
-        </dependency>
org.bouncycastle
bcprov-jdk18on
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java
index 73d6bfa148..e2325d8fc2 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java
@@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.polly;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.polly.AmazonPollyClient;
-import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
-import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
-import com.amazonaws.services.polly.model.TaskStatus;
-import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@@ -34,9 +25,17 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
+import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
+import software.amazon.awssdk.services.polly.PollyClient;
+import software.amazon.awssdk.services.polly.PollyClientBuilder;
+import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
+import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
+import software.amazon.awssdk.services.polly.model.TaskStatus;
+import java.util.HashSet;
+import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -45,10 +44,10 @@ import java.util.regex.Pattern;
@SeeAlso({StartAwsPollyJob.class})
@WritesAttributes({
@WritesAttribute(attribute = "PollyS3OutputBucket", description = "The bucket name where polly output will be located."),
- @WritesAttribute(attribute = "PollyS3OutputKey", description = "Object key of polly output."),
+ @WritesAttribute(attribute = "filename", description = "Object key of polly output."),
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
-public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor {
+public class GetAwsPollyJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<PollyClient, PollyClientBuilder> {
private static final String BUCKET = "bucket";
private static final String KEY = "key";
private static final Pattern S3_PATH = Pattern.compile("https://s3.*amazonaws.com/(?<" + BUCKET + ">[^/]+)/(?<" + KEY + ">.*)");
@@ -56,65 +55,66 @@ public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor getRelationships() {
+        final Set<Relationship> parentRelationships = new HashSet<>(super.getRelationships());
+ parentRelationships.remove(REL_THROTTLED);
+ return Set.copyOf(parentRelationships);
}
@Override
- public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
+ protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
+ return PollyClient.builder();
+ }
+
+ @Override
+ public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
- GetSpeechSynthesisTaskResult speechSynthesisTask;
+ final GetSpeechSynthesisTaskResponse speechSynthesisTask;
try {
speechSynthesisTask = getSynthesisTask(context, flowFile);
- } catch (ThrottlingException e) {
- getLogger().info("Request Rate Limit exceeded", e);
- session.transfer(flowFile, REL_THROTTLED);
- return;
- } catch (Exception e) {
+ } catch (final Exception e) {
getLogger().warn("Failed to get Polly Job status", e);
session.transfer(flowFile, REL_FAILURE);
return;
}
- TaskStatus taskStatus = TaskStatus.fromValue(speechSynthesisTask.getSynthesisTask().getTaskStatus());
+ final TaskStatus taskStatus = speechSynthesisTask.synthesisTask().taskStatus();
- if (taskStatus == TaskStatus.InProgress || taskStatus == TaskStatus.Scheduled) {
- session.penalize(flowFile);
+ if (taskStatus == TaskStatus.IN_PROGRESS || taskStatus == TaskStatus.SCHEDULED) {
+ flowFile = session.penalize(flowFile);
session.transfer(flowFile, REL_RUNNING);
- } else if (taskStatus == TaskStatus.Completed) {
- String outputUri = speechSynthesisTask.getSynthesisTask().getOutputUri();
+ } else if (taskStatus == TaskStatus.COMPLETED) {
+ final String outputUri = speechSynthesisTask.synthesisTask().outputUri();
- Matcher matcher = S3_PATH.matcher(outputUri);
+ final Matcher matcher = S3_PATH.matcher(outputUri);
if (matcher.find()) {
- session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET));
- session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY));
+ flowFile = session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET));
+ flowFile = session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY));
}
FlowFile childFlowFile = session.create(flowFile);
- writeToFlowFile(session, childFlowFile, speechSynthesisTask);
+ childFlowFile = writeToFlowFile(session, childFlowFile, speechSynthesisTask);
childFlowFile = session.putAttribute(childFlowFile, AWS_TASK_OUTPUT_LOCATION, outputUri);
session.transfer(flowFile, REL_ORIGINAL);
session.transfer(childFlowFile, REL_SUCCESS);
getLogger().info("Amazon Polly Task Completed {}", flowFile);
- } else if (taskStatus == TaskStatus.Failed) {
- final String failureReason = speechSynthesisTask.getSynthesisTask().getTaskStatusReason();
+ } else if (taskStatus == TaskStatus.FAILED) {
+ final String failureReason = speechSynthesisTask.synthesisTask().taskStatusReason();
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Polly Task Failed {} Reason [{}]", flowFile, failureReason);
+ } else if (taskStatus == TaskStatus.UNKNOWN_TO_SDK_VERSION) {
+ flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
+ session.transfer(flowFile, REL_FAILURE);
+ getLogger().error("Amazon Polly Task Failed {} Reason [Unrecognized job status]", flowFile);
}
}
- private GetSpeechSynthesisTaskResult getSynthesisTask(ProcessContext context, FlowFile flowFile) {
- String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
- GetSpeechSynthesisTaskRequest request = new GetSpeechSynthesisTaskRequest().withTaskId(taskId);
+ private GetSpeechSynthesisTaskResponse getSynthesisTask(final ProcessContext context, final FlowFile flowFile) {
+ final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
+ final GetSpeechSynthesisTaskRequest request = GetSpeechSynthesisTaskRequest.builder().taskId(taskId).build();
return getClient(context).getSpeechSynthesisTask(request);
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java
index 0ab52ac597..8052b92558 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java
@@ -17,49 +17,45 @@
package org.apache.nifi.processors.aws.ml.polly;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.polly.AmazonPollyClient;
-import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskRequest;
-import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskResult;
+import org.apache.nifi.annotation.behavior.WritesAttribute;
+import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
+import software.amazon.awssdk.services.polly.PollyClient;
+import software.amazon.awssdk.services.polly.PollyClientBuilder;
+import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
+import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Polly"})
@CapabilityDescription("Trigger a AWS Polly job. It should be followed by GetAwsPollyJobStatus processor in order to monitor job status.")
+@WritesAttributes({
+ @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsPollyJobStatus")
+})
@SeeAlso({GetAwsPollyJobStatus.class})
-public class StartAwsPollyJob extends AwsMachineLearningJobStarter {
+public class StartAwsPollyJob extends AbstractAwsMachineLearningJobStarter<
+ StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskRequest.Builder, StartSpeechSynthesisTaskResponse, PollyClient, PollyClientBuilder> {
@Override
- protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
-
- return (AmazonPollyClient) AmazonPollyClient.builder()
- .withRegion(context.getProperty(REGION).getValue())
- .withCredentials(credentialsProvider)
- .withClientConfiguration(config)
- .withEndpointConfiguration(endpointConfiguration)
- .build();
+ protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
+ return PollyClient.builder();
}
@Override
- protected StartSpeechSynthesisTaskResult sendRequest(StartSpeechSynthesisTaskRequest request, ProcessContext context, FlowFile flowFile) {
+ protected StartSpeechSynthesisTaskResponse sendRequest(final StartSpeechSynthesisTaskRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startSpeechSynthesisTask(request);
}
@Override
-    protected Class<? extends StartSpeechSynthesisTaskRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
- return StartSpeechSynthesisTaskRequest.class;
+    protected Class<? extends StartSpeechSynthesisTaskRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
+ return StartSpeechSynthesisTaskRequest.serializableBuilderClass();
}
@Override
- protected String getAwsTaskId(ProcessContext context, StartSpeechSynthesisTaskResult startSpeechSynthesisTaskResult, FlowFile flowFile) {
- return startSpeechSynthesisTaskResult.getSynthesisTask().getTaskId();
+ protected String getAwsTaskId(final ProcessContext context, final StartSpeechSynthesisTaskResponse startSpeechSynthesisTaskResponse, final FlowFile flowFile) {
+ return startSpeechSynthesisTaskResponse.synthesisTask().taskId();
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java
index ce48142f0c..2687735217 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java
@@ -17,48 +17,61 @@
package org.apache.nifi.processors.aws.ml.textract;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.textract.AmazonTextractClient;
-import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
-import com.amazonaws.services.textract.model.GetDocumentTextDetectionRequest;
-import com.amazonaws.services.textract.model.GetExpenseAnalysisRequest;
-import com.amazonaws.services.textract.model.JobStatus;
-import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor;
+import org.apache.nifi.components.ValidationContext;
+import org.apache.nifi.components.ValidationResult;
+import org.apache.nifi.components.Validator;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
-import org.apache.nifi.processor.util.StandardValidators;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
+import software.amazon.awssdk.services.textract.TextractClient;
+import software.amazon.awssdk.services.textract.TextractClientBuilder;
+import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
+import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.JobStatus;
+import software.amazon.awssdk.services.textract.model.TextractResponse;
+import software.amazon.awssdk.services.textract.model.ThrottlingException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
-import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS;
+import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Retrieves the current status of an AWS Textract job.")
@SeeAlso({StartAwsTextractJob.class})
-public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcessor {
+public class GetAwsTextractJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TextractClient, TextractClientBuilder> {
+
+ public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() {
+ @Override
+ public ValidationResult validate(final String subject, final String value, final ValidationContext context) {
+ if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
+ return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
+ } else if (TextractType.TEXTRACT_TYPES.contains(value)) {
+ return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
+ } else {
+ return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
+ }
+ }
+ };
+
public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
.name("textract-type")
.displayName("Textract Type")
.required(true)
.description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
- .allowableValues(TextractType.TEXTRACT_TYPES)
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
- .defaultValue(DOCUMENT_ANALYSIS.getType())
- .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
+ .defaultValue(String.format("${%s}", TEXTRACT_TYPE_ATTRIBUTE))
+ .addValidator(TEXTRACT_TYPE_VALIDATOR)
.build();
     private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@@ -69,30 +82,24 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
}
@Override
- protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
- return (AmazonTextractClient) AmazonTextractClient.builder()
- .withRegion(context.getProperty(REGION).getValue())
- .withClientConfiguration(config)
- .withEndpointConfiguration(endpointConfiguration)
- .withCredentials(credentialsProvider)
- .build();
+ protected TextractClientBuilder createClientBuilder(final ProcessContext context) {
+ return TextractClient.builder();
}
@Override
- public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
+ public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
- String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue();
+ final String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue();
- String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
+ final String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
try {
- JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId);
+ final JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId);
if (JobStatus.SUCCEEDED == jobStatus) {
- Object task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId);
- writeToFlowFile(session, flowFile, task);
+ final TextractResponse task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId);
+ flowFile = writeToFlowFile(session, flowFile, task);
session.transfer(flowFile, REL_SUCCESS);
} else if (JobStatus.IN_PROGRESS == jobStatus) {
session.transfer(flowFile, REL_RUNNING);
@@ -101,29 +108,31 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
} else if (JobStatus.FAILED == jobStatus) {
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Textract Task [{}] Failed", awsTaskId);
+ } else {
+ throw new IllegalStateException("Unrecognized job status");
}
- } catch (ThrottlingException e) {
+ } catch (final ThrottlingException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
- } catch (Exception e) {
+ } catch (final Exception e) {
getLogger().warn("Failed to get Textract Job status", e);
session.transfer(flowFile, REL_FAILURE);
}
}
- private Object getTask(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) {
+ private TextractResponse getTask(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
return switch (typeOfTextract) {
- case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId));
- case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId));
- case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId));
+ case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build());
+ case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build());
+ case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build());
};
}
- private JobStatus getTaskStatus(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) {
+ private JobStatus getTaskStatus(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
return switch (typeOfTextract) {
- case DOCUMENT_ANALYSIS -> JobStatus.fromValue(client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId)).getJobStatus());
- case DOCUMENT_TEXT_DETECTION -> JobStatus.fromValue(client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId)).getJobStatus());
- case EXPENSE_ANALYSIS -> JobStatus.fromValue(client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId)).getJobStatus());
+ case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
+ case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build()).jobStatus();
+ case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
};
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java
index 70ff2391a8..301aadc880 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java
@@ -17,31 +17,26 @@
package org.apache.nifi.processors.aws.ml.textract;
-import com.amazonaws.AmazonWebServiceRequest;
-import com.amazonaws.AmazonWebServiceResult;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.textract.AmazonTextractClient;
-import com.amazonaws.services.textract.model.StartDocumentAnalysisRequest;
-import com.amazonaws.services.textract.model.StartDocumentAnalysisResult;
-import com.amazonaws.services.textract.model.StartDocumentTextDetectionRequest;
-import com.amazonaws.services.textract.model.StartDocumentTextDetectionResult;
-import com.amazonaws.services.textract.model.StartExpenseAnalysisRequest;
-import com.amazonaws.services.textract.model.StartExpenseAnalysisResult;
+import org.apache.nifi.annotation.behavior.WritesAttribute;
+import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor;
-import org.apache.nifi.components.ValidationContext;
-import org.apache.nifi.components.ValidationResult;
-import org.apache.nifi.components.Validator;
-import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
+import software.amazon.awssdk.services.textract.TextractClient;
+import software.amazon.awssdk.services.textract.TextractClientBuilder;
+import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
+import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
+import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
+import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;
+import software.amazon.awssdk.services.textract.model.TextractRequest;
+import software.amazon.awssdk.services.textract.model.TextractResponse;
import java.util.Collections;
import java.util.List;
@@ -52,28 +47,23 @@ import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_A
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Trigger a AWS Textract job. It should be followed by GetAwsTextractJobStatus processor in order to monitor job status.")
+@WritesAttributes({
+ @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTextractJobStatus"),
+ @WritesAttribute(attribute = "awsTextractType", description = "The selected Textract type, which can be used in GetAwsTextractJobStatus")
+})
@SeeAlso({GetAwsTextractJobStatus.class})
-public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonTextractClient, AmazonWebServiceRequest, AmazonWebServiceResult> {
- public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() {
- @Override
- public ValidationResult validate(final String subject, final String value, final ValidationContext context) {
- if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
- return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
- } else if (TextractType.TEXTRACT_TYPES.contains(value)) {
- return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
- } else {
- return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
- }
- }
- };
+public class StartAwsTextractJob extends AbstractAwsMachineLearningJobStarter<
+ TextractRequest, TextractRequest.Builder, TextractResponse, TextractClient, TextractClientBuilder> {
+
+ public static final String TEXTRACT_TYPE_ATTRIBUTE = "awsTextractType";
+
public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
.name("textract-type")
.displayName("Textract Type")
.required(true)
.description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
- .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
- .defaultValue(DOCUMENT_ANALYSIS.type)
- .addValidator(TEXTRACT_TYPE_VALIDATOR)
+ .allowableValues(TextractType.TEXTRACT_TYPES)
+ .defaultValue(DOCUMENT_ANALYSIS.getType())
.build();
    private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@@ -84,24 +74,13 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonTextractClient, AmazonWebServiceRequest, AmazonWebServiceResult> {
            case DOCUMENT_ANALYSIS -> getClient(context).startDocumentAnalysis((StartDocumentAnalysisRequest) request);
case DOCUMENT_TEXT_DETECTION -> getClient(context).startDocumentTextDetection((StartDocumentTextDetectionRequest) request);
@@ -110,22 +89,28 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonTextractClient, AmazonWebServiceRequest, AmazonWebServiceResult> {
-    protected Class<? extends AmazonWebServiceRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
- final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
+ protected Class<? extends TextractRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
+ final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return switch (typeOfTextract) {
- case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.class;
- case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.class;
- case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.class;
+ case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.serializableBuilderClass();
+ case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.serializableBuilderClass();
+ case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.serializableBuilderClass();
};
}
@Override
- protected String getAwsTaskId(ProcessContext context, AmazonWebServiceResult amazonWebServiceResult, FlowFile flowFile) {
- final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
+ protected String getAwsTaskId(final ProcessContext context, final TextractResponse textractResponse, final FlowFile flowFile) {
+ final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return switch (textractType) {
- case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResult) amazonWebServiceResult).getJobId();
- case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResult) amazonWebServiceResult).getJobId();
- case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResult) amazonWebServiceResult).getJobId();
+ case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResponse) textractResponse).jobId();
+ case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResponse) textractResponse).jobId();
+ case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResponse) textractResponse).jobId();
};
}
+
+ @Override
+ protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, FlowFile flowFile, final TextractResponse response) {
+ flowFile = super.postProcessFlowFile(context, session, flowFile, response);
+ return session.putAttribute(flowFile, TEXTRACT_TYPE_ATTRIBUTE, context.getProperty(TEXTRACT_TYPE).getValue());
+ }
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java
index ff48af3843..9ad549b75e 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java
@@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.transcribe;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.textract.model.ThrottlingException;
-import com.amazonaws.services.transcribe.AmazonTranscribeClient;
-import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
-import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
-import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@@ -35,7 +26,13 @@ import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
+import software.amazon.awssdk.services.transcribe.TranscribeClient;
+import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
+import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
+import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
+import software.amazon.awssdk.services.transcribe.model.LimitExceededException;
+import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
@CapabilityDescription("Retrieves the current status of an AWS Transcribe job.")
@@ -43,55 +40,50 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
@WritesAttributes({
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
-public class GetAwsTranscribeJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranscribeClient> {
+public class GetAwsTranscribeJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranscribeClient, TranscribeClientBuilder> {
@Override
- protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
- return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
- .withRegion(context.getProperty(REGION).getValue())
- .withCredentials(credentialsProvider)
- .withEndpointConfiguration(endpointConfiguration)
- .withClientConfiguration(config)
- .build();
+ protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
+ return TranscribeClient.builder();
}
@Override
- public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
+ public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
try {
- GetTranscriptionJobResult job = getJob(context, flowFile);
- TranscriptionJobStatus jobStatus = TranscriptionJobStatus.fromValue(job.getTranscriptionJob().getTranscriptionJobStatus());
+ final GetTranscriptionJobResponse job = getJob(context, flowFile);
+ final TranscriptionJobStatus status = job.transcriptionJob().transcriptionJobStatus();
- if (TranscriptionJobStatus.COMPLETED == jobStatus) {
- writeToFlowFile(session, flowFile, job);
- session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.getTranscriptionJob().getTranscript().getTranscriptFileUri());
+ if (TranscriptionJobStatus.COMPLETED == status) {
+ flowFile = writeToFlowFile(session, flowFile, job);
+ flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.transcriptionJob().transcript().transcriptFileUri());
session.transfer(flowFile, REL_SUCCESS);
- } else if (TranscriptionJobStatus.IN_PROGRESS == jobStatus) {
+ } else if (TranscriptionJobStatus.IN_PROGRESS == status || TranscriptionJobStatus.QUEUED == status) {
session.transfer(flowFile, REL_RUNNING);
- } else if (TranscriptionJobStatus.FAILED == jobStatus) {
- final String failureReason = job.getTranscriptionJob().getFailureReason();
- session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
+ } else if (TranscriptionJobStatus.FAILED == status) {
+ final String failureReason = job.transcriptionJob().failureReason();
+ flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Transcribe Task Failed {} Reason [{}]", flowFile, failureReason);
+ } else {
+ flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
+ throw new IllegalStateException("Unrecognized job status");
}
- } catch (ThrottlingException e) {
+ } catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
- return;
- } catch (Exception e) {
+ } catch (final Exception e) {
getLogger().warn("Failed to get Transcribe Job status", e);
session.transfer(flowFile, REL_FAILURE);
- return;
}
}
- private GetTranscriptionJobResult getJob(ProcessContext context, FlowFile flowFile) {
- String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
- GetTranscriptionJobRequest request = new GetTranscriptionJobRequest().withTranscriptionJobName(taskId);
+ private GetTranscriptionJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
+ final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
+ final GetTranscriptionJobRequest request = GetTranscriptionJobRequest.builder().transcriptionJobName(taskId).build();
return getClient(context).getTranscriptionJob(request);
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java
index a91192e5c6..5af6423224 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java
@@ -17,48 +17,45 @@
package org.apache.nifi.processors.aws.ml.transcribe;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.transcribe.AmazonTranscribeClient;
-import com.amazonaws.services.transcribe.model.StartTranscriptionJobRequest;
-import com.amazonaws.services.transcribe.model.StartTranscriptionJobResult;
+import org.apache.nifi.annotation.behavior.WritesAttribute;
+import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
+import software.amazon.awssdk.services.transcribe.TranscribeClient;
+import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
+import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
+import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
@CapabilityDescription("Trigger a AWS Transcribe job. It should be followed by GetAwsTranscribeStatus processor in order to monitor job status.")
+@WritesAttributes({
+ @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranscribeJobStatus")
+})
@SeeAlso({GetAwsTranscribeJobStatus.class})
-public class StartAwsTranscribeJob extends AwsMachineLearningJobStarter<AmazonTranscribeClient, StartTranscriptionJobRequest, StartTranscriptionJobResult> {
+public class StartAwsTranscribeJob extends AbstractAwsMachineLearningJobStarter<
+ StartTranscriptionJobRequest, StartTranscriptionJobRequest.Builder, StartTranscriptionJobResponse, TranscribeClient, TranscribeClientBuilder> {
@Override
- protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
- return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
- .withRegion(context.getProperty(REGION).getValue())
- .withClientConfiguration(config)
- .withEndpointConfiguration(endpointConfiguration)
- .withCredentials(credentialsProvider)
- .build();
+ protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
+ return TranscribeClient.builder();
}
@Override
- protected StartTranscriptionJobResult sendRequest(StartTranscriptionJobRequest request, ProcessContext context, FlowFile flowFile) {
+ protected StartTranscriptionJobResponse sendRequest(final StartTranscriptionJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTranscriptionJob(request);
}
@Override
- protected Class<? extends StartTranscriptionJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
- return StartTranscriptionJobRequest.class;
+ protected Class<? extends StartTranscriptionJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
+ return StartTranscriptionJobRequest.serializableBuilderClass();
}
@Override
- protected String getAwsTaskId(ProcessContext context, StartTranscriptionJobResult startTranscriptionJobResult, FlowFile flowFile) {
- return startTranscriptionJobResult.getTranscriptionJob().getTranscriptionJobName();
+ protected String getAwsTaskId(final ProcessContext context, final StartTranscriptionJobResponse startTranscriptionJobResponse, final FlowFile flowFile) {
+ return startTranscriptionJobResponse.transcriptionJob().transcriptionJobName();
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java
index bc52a23efd..18c8dbaeea 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java
@@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.translate;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.textract.model.ThrottlingException;
-import com.amazonaws.services.translate.AmazonTranslateClient;
-import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
-import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
-import com.amazonaws.services.translate.model.JobStatus;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@@ -34,8 +25,15 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
+import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
+import software.amazon.awssdk.services.translate.TranslateClient;
+import software.amazon.awssdk.services.translate.TranslateClientBuilder;
+import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
+import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
+import software.amazon.awssdk.services.translate.model.JobStatus;
+import software.amazon.awssdk.services.translate.model.LimitExceededException;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Retrieves the current status of an AWS Translate job.")
@@ -43,54 +41,66 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
@WritesAttributes({
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
-public class GetAwsTranslateJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranslateClient> {
+public class GetAwsTranslateJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranslateClient, TranslateClientBuilder> {
@Override
- protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
- return (AmazonTranslateClient) AmazonTranslateClient.builder()
- .withRegion(context.getProperty(REGION).getValue())
- .withCredentials(credentialsProvider)
- .withClientConfiguration(config)
- .withEndpointConfiguration(endpointConfiguration)
- .build();
+ protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
+ return TranslateClient.builder();
}
@Override
- public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
+ public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
- String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
try {
- DescribeTextTranslationJobResult describeTextTranslationJobResult = getStatusString(context, awsTaskId);
- JobStatus status = JobStatus.fromValue(describeTextTranslationJobResult.getTextTranslationJobProperties().getJobStatus());
+ final DescribeTextTranslationJobResponse job = getJob(context, flowFile);
+ final JobStatus status = job.textTranslationJobProperties().jobStatus();
- if (status == JobStatus.IN_PROGRESS || status == JobStatus.SUBMITTED) {
- writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
- session.penalize(flowFile);
- session.transfer(flowFile, REL_RUNNING);
- } else if (status == JobStatus.COMPLETED) {
- session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, describeTextTranslationJobResult.getTextTranslationJobProperties().getOutputDataConfig().getS3Uri());
- writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
- session.transfer(flowFile, REL_SUCCESS);
- } else if (status == JobStatus.FAILED || status == JobStatus.COMPLETED_WITH_ERROR) {
- writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
- session.transfer(flowFile, REL_FAILURE);
+ flowFile = writeToFlowFile(session, flowFile, job);
+ final Relationship transferRelationship;
+ String failureReason = null;
+ switch (status) {
+ case IN_PROGRESS:
+ case SUBMITTED:
+ case STOP_REQUESTED:
+ flowFile = session.penalize(flowFile);
+ transferRelationship = REL_RUNNING;
+ break;
+ case COMPLETED:
+ flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.textTranslationJobProperties().outputDataConfig().s3Uri());
+ transferRelationship = REL_SUCCESS;
+ break;
+ case FAILED:
+ case COMPLETED_WITH_ERROR:
+ failureReason = job.textTranslationJobProperties().message();
+ transferRelationship = REL_FAILURE;
+ break;
+ case STOPPED:
+ failureReason = String.format("Job [%s] is stopped", job.textTranslationJobProperties().jobId());
+ transferRelationship = REL_FAILURE;
+ break;
+ default:
+ failureReason = "Unknown Job Status";
+ transferRelationship = REL_FAILURE;
}
- } catch (ThrottlingException e) {
+ if (failureReason != null) {
+ flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
+ }
+ session.transfer(flowFile, transferRelationship);
+ } catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
- } catch (Exception e) {
- getLogger().warn("Failed to get Polly Job status", e);
+ } catch (final Exception e) {
+ getLogger().warn("Failed to get Translate Job status", e);
session.transfer(flowFile, REL_FAILURE);
}
}
- private DescribeTextTranslationJobResult getStatusString(ProcessContext context, String awsTaskId) {
- DescribeTextTranslationJobRequest request = new DescribeTextTranslationJobRequest().withJobId(awsTaskId);
- DescribeTextTranslationJobResult translationJobsResult = getClient(context).describeTextTranslationJob(request);
- return translationJobsResult;
+ private DescribeTextTranslationJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
+ final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
+ final DescribeTextTranslationJobRequest request = DescribeTextTranslationJobRequest.builder().jobId(taskId).build();
+ return getClient(context).describeTextTranslationJob(request);
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java
index 3e7f079df9..db75233b09 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java
@@ -17,47 +17,44 @@
package org.apache.nifi.processors.aws.ml.translate;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.translate.AmazonTranslateClient;
-import com.amazonaws.services.translate.model.StartTextTranslationJobRequest;
-import com.amazonaws.services.translate.model.StartTextTranslationJobResult;
+import org.apache.nifi.annotation.behavior.WritesAttribute;
+import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
-import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
+import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
+import software.amazon.awssdk.services.translate.TranslateClient;
+import software.amazon.awssdk.services.translate.TranslateClientBuilder;
+import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
+import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Trigger a AWS Translate job. It should be followed by GetAwsTranslateJobStatus processor in order to monitor job status.")
+@WritesAttributes({
+ @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranslateJobStatus")
+})
@SeeAlso({GetAwsTranslateJobStatus.class})
-public class StartAwsTranslateJob extends AwsMachineLearningJobStarter<AmazonTranslateClient, StartTextTranslationJobRequest, StartTextTranslationJobResult> {
+public class StartAwsTranslateJob extends AbstractAwsMachineLearningJobStarter<
+ StartTextTranslationJobRequest, StartTextTranslationJobRequest.Builder, StartTextTranslationJobResponse, TranslateClient, TranslateClientBuilder> {
@Override
- protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
- return (AmazonTranslateClient) AmazonTranslateClient.builder()
- .withRegion(context.getProperty(REGION).getValue())
- .withCredentials(credentialsProvider)
- .withClientConfiguration(config)
- .withEndpointConfiguration(endpointConfiguration)
- .build();
+ protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
+ return TranslateClient.builder();
}
@Override
- protected StartTextTranslationJobResult sendRequest(StartTextTranslationJobRequest request, ProcessContext context, FlowFile flowFile) {
+ protected StartTextTranslationJobResponse sendRequest(final StartTextTranslationJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTextTranslationJob(request);
}
@Override
- protected Class<? extends StartTextTranslationJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
- return StartTextTranslationJobRequest.class;
+ protected Class<? extends StartTextTranslationJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
+ return StartTextTranslationJobRequest.serializableBuilderClass();
}
- protected String getAwsTaskId(ProcessContext context, StartTextTranslationJobResult startTextTranslationJobResult, FlowFile flowFile) {
- return startTextTranslationJobResult.getJobId();
+ protected String getAwsTaskId(final ProcessContext context, final StartTextTranslationJobResponse startTextTranslationJobResponse, final FlowFile flowFile) {
+ return startTextTranslationJobResponse.jobId();
}
}
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java
index d9b7b51ffc..73af8f04c0 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java
@@ -17,18 +17,10 @@
package org.apache.nifi.processors.aws.ml.polly;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.polly.AmazonPollyClient;
-import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
-import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
-import com.amazonaws.services.polly.model.SynthesisTask;
-import com.amazonaws.services.polly.model.TaskStatus;
import org.apache.nifi.processor.ProcessContext;
-import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
+import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@@ -38,16 +30,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.services.polly.PollyClient;
+import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
+import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
+import software.amazon.awssdk.services.polly.model.SynthesisTask;
+import software.amazon.awssdk.services.polly.model.TaskStatus;
import java.util.Collections;
-import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@@ -57,73 +53,84 @@ public class GetAwsPollyStatusTest {
private static final String PLACEHOLDER_CONTENT = "content";
private TestRunner runner;
@Mock
- private AmazonPollyClient mockPollyClient;
- @Mock
- private AWSCredentialsProviderService mockAwsCredentialsProvider;
+ private PollyClient mockPollyClient;
+
+ private GetAwsPollyJobStatus processor;
+
@Captor
private ArgumentCaptor requestCaptor;
+ private TestRunner createRunner(final GetAwsPollyJobStatus processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
+
@BeforeEach
public void setUp() throws InitializationException {
- when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider");
-
- final GetAwsPollyJobStatus mockGetAwsPollyStatus = new GetAwsPollyJobStatus() {
+ processor = new GetAwsPollyJobStatus() {
@Override
- protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
+ public PollyClient getClient(final ProcessContext context) {
return mockPollyClient;
}
};
- runner = TestRunners.newTestRunner(mockGetAwsPollyStatus);
- runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
- runner.enableControllerService(mockAwsCredentialsProvider);
- runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
+ runner = createRunner(processor);
}
@Test
public void testPollyTaskInProgress() {
- GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
- SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
- .withTaskStatus(TaskStatus.InProgress);
- taskResult.setSynthesisTask(task);
- when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
+ GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
+ .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).taskStatus(TaskStatus.IN_PROGRESS).build())
+ .build();
+ when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
+ assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
}
@Test
public void testPollyTaskCompleted() {
- GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
- SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
- .withTaskStatus(TaskStatus.Completed)
- .withOutputUri("outputLocationPath");
- taskResult.setSynthesisTask(task);
- when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
+ final String uri = "https://s3.us-west2.amazonaws.com/bucket/object";
+ final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
+ .synthesisTask(SynthesisTask.builder()
+ .taskId(TEST_TASK_ID)
+ .taskStatus(TaskStatus.COMPLETED)
+ .outputUri(uri).build())
+ .build();
+ when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
runner.assertAllFlowFilesContainAttribute(REL_SUCCESS, AWS_TASK_OUTPUT_LOCATION);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
- }
+ assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
+ final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
+ assertEquals(uri, flowFile.getAttribute(GetAwsPollyJobStatus.AWS_TASK_OUTPUT_LOCATION));
+ assertEquals("bucket", flowFile.getAttribute("PollyS3OutputBucket"));
+ assertEquals("object", flowFile.getAttribute("filename"));
+ }
@Test
public void testPollyTaskFailed() {
- GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
- SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
- .withTaskStatus(TaskStatus.Failed)
- .withTaskStatusReason("reasonOfFailure");
- taskResult.setSynthesisTask(task);
- when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
+ final String failureReason = "reasonOfFailure";
+ final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
+ .synthesisTask(SynthesisTask.builder()
+ .taskId(TEST_TASK_ID)
+ .taskStatus(TaskStatus.FAILED)
+ .taskStatusReason(failureReason).build())
+ .build();
+ when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
+ assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
+
+ final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
+ assertEquals(failureReason, flowFile.getAttribute(GetAwsPollyJobStatus.FAILURE_REASON_ATTRIBUTE));
}
}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJobTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJobTest.java
new file mode 100644
index 0000000000..4e68e2d7a1
--- /dev/null
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJobTest.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nifi.processors.aws.ml.polly;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.MapperFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.json.JsonMapper;
+import org.apache.nifi.processor.ProcessContext;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
+import org.apache.nifi.reporting.InitializationException;
+import org.apache.nifi.util.TestRunner;
+import org.apache.nifi.util.TestRunners;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.awscore.exception.AwsServiceException;
+import software.amazon.awssdk.services.polly.PollyClient;
+import software.amazon.awssdk.services.polly.model.Engine;
+import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
+import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;
+import software.amazon.awssdk.services.polly.model.SynthesisTask;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class StartAwsPollyJobTest {
+ private static final String TEST_TASK_ID = "testTaskId";
+ private TestRunner runner;
+ @Mock
+ private PollyClient mockPollyClient;
+
+ private StartAwsPollyJob processor;
+
+ private ObjectMapper objectMapper = JsonMapper.builder()
+ .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
+ .build();
+ @Captor
+ private ArgumentCaptor requestCaptor;
+
+ private TestRunner createRunner(final StartAwsPollyJob processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
+
+ @BeforeEach
+ public void setUp() throws InitializationException {
+ processor = new StartAwsPollyJob() {
+ @Override
+ public PollyClient getClient(ProcessContext context) {
+ return mockPollyClient;
+ }
+ };
+ runner = createRunner(processor);
+ }
+
+ @Test
+ public void testSuccessfulFlowfileContent() throws JsonProcessingException {
+ final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
+ .engine(Engine.NEURAL)
+ .text("Text")
+ .build();
+ final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder()
+ .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build())
+ .build();
+ when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData);
+
+ assertEquals("Text", requestCaptor.getValue().text());
+ assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId());
+ }
+
+ @Test
+ public void testSuccessfulAttribute() throws JsonProcessingException {
+ final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
+ .engine(Engine.NEURAL)
+ .text("Text")
+ .build();
+ final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder()
+ .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build())
+ .build();
+ when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsPollyJob.JSON_PAYLOAD, "${json.payload}");
+ final Map attributes = new HashMap<>();
+ attributes.put("json.payload", requestJson);
+ runner.enqueue("", attributes);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData);
+
+ assertEquals("Text", requestCaptor.getValue().text());
+ assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId());
+ }
+
+ @Test
+ public void testInvalidJson() {
+ final String requestJson = "invalid";
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testServiceFailure() throws JsonProcessingException {
+ final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
+ .engine(Engine.NEURAL)
+ .text("Text")
+ .build();
+ when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ private StartSpeechSynthesisTaskResponse deserialize(final String responseData) throws JsonProcessingException {
+ return objectMapper.readValue(responseData, StartSpeechSynthesisTaskResponse.serializableBuilderClass()).build();
+ }
+
+ private String serialize(final StartSpeechSynthesisTaskRequest request) throws JsonProcessingException {
+ return objectMapper.writeValueAsString(request.toBuilder());
+ }
+}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java
index 82fb01cdcf..9cead6ef96 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java
@@ -17,17 +17,10 @@
package org.apache.nifi.processors.aws.ml.textract;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.textract.AmazonTextractClient;
-import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
-import com.amazonaws.services.textract.model.GetDocumentAnalysisResult;
-import com.amazonaws.services.textract.model.JobStatus;
import com.google.common.collect.ImmutableMap;
import org.apache.nifi.processor.ProcessContext;
-import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
+import org.apache.nifi.processor.Relationship;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
@@ -38,13 +31,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.services.textract.TextractClient;
+import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisResponse;
+import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
+import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionResponse;
+import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisResponse;
+import software.amazon.awssdk.services.textract.model.JobStatus;
-import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
-import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_THROTTLED;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@@ -53,64 +53,144 @@ public class GetAwsTextractJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
- private AmazonTextractClient mockTextractClient;
- @Mock
- private AWSCredentialsProviderService mockAwsCredentialsProvider;
+ private TextractClient mockTextractClient;
+
+ private GetAwsTextractJobStatus processor;
+
@Captor
- private ArgumentCaptor requestCaptor;
+ private ArgumentCaptor documentAnalysisCaptor;
+ @Captor
+ private ArgumentCaptor expenseAnalysisRequestCaptor;
+ @Captor
+ private ArgumentCaptor documentTextDetectionCaptor;
+
+ private TestRunner createRunner(final GetAwsTextractJobStatus processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
@BeforeEach
public void setUp() throws InitializationException {
- when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider");
- final GetAwsTextractJobStatus awsTextractJobStatusGetter = new GetAwsTextractJobStatus() {
+ processor = new GetAwsTextractJobStatus() {
@Override
- protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
+ public TextractClient getClient(final ProcessContext context) {
return mockTextractClient;
}
};
- runner = TestRunners.newTestRunner(awsTextractJobStatusGetter);
- runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
- runner.enableControllerService(mockAwsCredentialsProvider);
- runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
+ runner = createRunner(processor);
}
@Test
public void testTextractDocAnalysisTaskInProgress() {
- GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
- .withJobStatus(JobStatus.IN_PROGRESS);
- when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
- TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
- runner.run();
-
- runner.assertAllFlowFilesTransferred(REL_RUNNING);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
+ testTextractDocAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
}
@Test
public void testTextractDocAnalysisTaskComplete() {
- GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
- .withJobStatus(JobStatus.SUCCEEDED);
- when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
- TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
- runner.run();
-
- runner.assertAllFlowFilesTransferred(REL_SUCCESS);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
+ testTextractDocAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
}
@Test
public void testTextractDocAnalysisTaskFailed() {
- GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
- .withJobStatus(JobStatus.FAILED);
- when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
- TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.type));
+ testTextractDocAnalysis(JobStatus.FAILED, REL_FAILURE);
+ }
+
+ @Test
+ public void testTextractDocAnalysisTaskPartialSuccess() {
+ testTextractDocAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
+ }
+
+ @Test
+ public void testTextractDocAnalysisTaskUnkownStatus() {
+ testTextractDocAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
+ }
+
+ private void testTextractDocAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
+ final GetDocumentAnalysisResponse response = GetDocumentAnalysisResponse.builder()
+ .jobStatus(jobStatus).build();
+ when(mockTextractClient.getDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
+ runner.enqueue("content", ImmutableMap.of(
+ TASK_ID.getName(), TEST_TASK_ID,
+ StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE, TextractType.DOCUMENT_ANALYSIS.getType()));
runner.run();
- runner.assertAllFlowFilesTransferred(REL_FAILURE);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
+ runner.assertAllFlowFilesTransferred(expectedRelationship);
+ assertEquals(TEST_TASK_ID, documentAnalysisCaptor.getValue().jobId());
+ }
+
+ @Test
+ public void testTextractExpenseAnalysisTaskInProgress() {
+ testTextractExpenseAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
+ }
+
+ @Test
+ public void testTextractExpenseAnalysisTaskComplete() {
+ testTextractExpenseAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
+ }
+
+ @Test
+ public void testTextractExpenseAnalysisTaskFailed() {
+ testTextractExpenseAnalysis(JobStatus.FAILED, REL_FAILURE);
+ }
+
+ @Test
+ public void testTextractExpenseAnalysisTaskPartialSuccess() {
+ testTextractExpenseAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
+ }
+
+ @Test
+ public void testTextractExpenseAnalysisTaskUnkownStatus() {
+ testTextractExpenseAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
+ }
+
+ private void testTextractExpenseAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
+ runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.EXPENSE_ANALYSIS.getType());
+ final GetExpenseAnalysisResponse response = GetExpenseAnalysisResponse.builder()
+ .jobStatus(jobStatus).build();
+ when(mockTextractClient.getExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);
+ runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(expectedRelationship);
+ assertEquals(TEST_TASK_ID, expenseAnalysisRequestCaptor.getValue().jobId());
+ }
+
+ @Test
+ public void testTextractDocumentTextDetectionTaskInProgress() {
+ testTextractDocumentTextDetection(JobStatus.IN_PROGRESS, REL_RUNNING);
+ }
+
+ @Test
+ public void testTextractDocumentTextDetectionTaskComplete() {
+ testTextractDocumentTextDetection(JobStatus.SUCCEEDED, REL_SUCCESS);
+ }
+
+ @Test
+ public void testTextractDocumentTextDetectionTaskFailed() {
+ testTextractDocumentTextDetection(JobStatus.FAILED, REL_FAILURE);
+ }
+
+ @Test
+ public void testTextractDocumentTextDetectionTaskPartialSuccess() {
+ testTextractDocumentTextDetection(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
+ }
+
+ @Test
+ public void testTextractDocumentTextDetectionTaskUnkownStatus() {
+ testTextractDocumentTextDetection(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
+ }
+
+ private void testTextractDocumentTextDetection(final JobStatus jobStatus, final Relationship expectedRelationship) {
+ runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.DOCUMENT_TEXT_DETECTION.getType());
+
+ final GetDocumentTextDetectionResponse response = GetDocumentTextDetectionResponse.builder()
+ .jobStatus(jobStatus).build();
+ when(mockTextractClient.getDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);
+ runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(expectedRelationship);
+ assertEquals(TEST_TASK_ID, documentTextDetectionCaptor.getValue().jobId());
}
}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJobStatusTest.java
new file mode 100644
index 0000000000..16a4b846df
--- /dev/null
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJobStatusTest.java
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nifi.processors.aws.ml.textract;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.MapperFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.json.JsonMapper;
+import org.apache.nifi.processor.ProcessContext;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
+import org.apache.nifi.reporting.InitializationException;
+import org.apache.nifi.util.TestRunner;
+import org.apache.nifi.util.TestRunners;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.awscore.exception.AwsServiceException;
+import software.amazon.awssdk.services.textract.TextractClient;
+import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
+import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
+import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
+import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
+import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS;
+import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_TEXT_DETECTION;
+import static org.apache.nifi.processors.aws.ml.textract.TextractType.EXPENSE_ANALYSIS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class StartAwsTextractJobStatusTest {
+ private static final String TEST_TASK_ID = "testTaskId";
+ private TestRunner runner;
+ @Mock
+ private TextractClient mockTextractClient;
+
+ private StartAwsTextractJob processor;
+
+ private ObjectMapper objectMapper = JsonMapper.builder()
+ .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
+ .build();
+ @Captor
+ private ArgumentCaptor<StartDocumentAnalysisRequest> documentAnalysisCaptor;
+ @Captor
+ private ArgumentCaptor<StartExpenseAnalysisRequest> expenseAnalysisRequestCaptor;
+ @Captor
+ private ArgumentCaptor<StartDocumentTextDetectionRequest> documentTextDetectionCaptor;
+
+ private TestRunner createRunner(final StartAwsTextractJob processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
+
+ @BeforeEach
+ public void setUp() throws InitializationException {
+ processor = new StartAwsTextractJob() {
+ @Override
+ public TextractClient getClient(ProcessContext context) {
+ return mockTextractClient;
+ }
+ };
+ runner = createRunner(processor);
+ }
+
+ @Test
+ public void testSuccessfulDocumentAnalysisFlowfileContent() throws JsonProcessingException {
+ final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
+ .jobTag("Tag")
+ .build();
+ final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder()
+ .jobId(TEST_TASK_ID)
+ .build();
+ when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData);
+
+ assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
+ assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+ }
+
+ @Test
+ public void testSuccessfulDocumentAnalysisAttribute() throws JsonProcessingException {
+ final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
+ .jobTag("Tag")
+ .build();
+ final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder()
+ .jobId(TEST_TASK_ID)
+ .build();
+ when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
+ runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
+ final Map<String, String> attributes = new HashMap<>();
+ attributes.put("json.payload", requestJson);
+ runner.enqueue("", attributes);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData);
+
+ assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
+ assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+ }
+
+ @Test
+ public void testInvalidDocumentAnalysisJson() {
+ final String requestJson = "invalid";
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testDocumentAnalysisServiceFailure() throws JsonProcessingException {
+ final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
+ .jobTag("Tag")
+ .build();
+ when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testSuccessfulExpenseAnalysisFlowfileContent() throws JsonProcessingException {
+ final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
+ .jobTag("Tag")
+ .build();
+ final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder()
+ .jobId(TEST_TASK_ID)
+ .build();
+ when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData);
+
+ assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
+ assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+ }
+
+ @Test
+ public void testSuccessfulExpenseAnalysisAttribute() throws JsonProcessingException {
+ final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
+ .jobTag("Tag")
+ .build();
+ final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder()
+ .jobId(TEST_TASK_ID)
+ .build();
+ when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
+ runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
+ final Map<String, String> attributes = new HashMap<>();
+ attributes.put("json.payload", requestJson);
+ runner.enqueue("", attributes);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData);
+
+ assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
+ assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+ }
+
+ @Test
+ public void testInvalidExpenseAnalysisJson() {
+ final String requestJson = "invalid";
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testExpenseAnalysisServiceFailure() throws JsonProcessingException {
+ final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
+ .jobTag("Tag")
+ .build();
+ when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testSuccessfulDocumentTextDetectionFlowfileContent() throws JsonProcessingException {
+ final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
+ .jobTag("Tag")
+ .build();
+ final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder()
+ .jobId(TEST_TASK_ID)
+ .build();
+ when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData);
+
+ assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
+ assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+ }
+
+ @Test
+ public void testSuccessfulDocumentTextDetectionAttribute() throws JsonProcessingException {
+ final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
+ .jobTag("Tag")
+ .build();
+ final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder()
+ .jobId(TEST_TASK_ID)
+ .build();
+ when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
+ runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
+ final Map<String, String> attributes = new HashMap<>();
+ attributes.put("json.payload", requestJson);
+ runner.enqueue("", attributes);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData);
+
+ assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
+ assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+ }
+
+ @Test
+ public void testInvalidDocumentTextDetectionJson() {
+ final String requestJson = "invalid";
+ runner.enqueue(requestJson);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testDocumentTextDetectionServiceFailure() throws JsonProcessingException {
+ final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
+ .jobTag("Tag")
+ .build();
+ when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ private StartDocumentTextDetectionResponse deserializeDTDRequest(final String responseData) throws JsonProcessingException {
+ return objectMapper.readValue(responseData, StartDocumentTextDetectionResponse.serializableBuilderClass()).build();
+ }
+
+ private StartDocumentAnalysisResponse deserializeDARequest(final String responseData) throws JsonProcessingException {
+ return objectMapper.readValue(responseData, StartDocumentAnalysisResponse.serializableBuilderClass()).build();
+ }
+
+ private StartExpenseAnalysisResponse deserializeEARequest(final String responseData) throws JsonProcessingException {
+ return objectMapper.readValue(responseData, StartExpenseAnalysisResponse.serializableBuilderClass()).build();
+ }
+
+ private String serialize(final StartDocumentAnalysisRequest request) throws JsonProcessingException {
+ return objectMapper.writeValueAsString(request.toBuilder());
+ }
+
+ private String serialize(final StartExpenseAnalysisRequest request) throws JsonProcessingException {
+ return objectMapper.writeValueAsString(request.toBuilder());
+ }
+
+ private String serialize(final StartDocumentTextDetectionRequest request) throws JsonProcessingException {
+ return objectMapper.writeValueAsString(request.toBuilder());
+ }
+}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java
index 68a253dc7a..9e139d2c31 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java
@@ -17,19 +17,11 @@
package org.apache.nifi.processors.aws.ml.transcribe;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.transcribe.AmazonTranscribeClient;
-import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
-import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
-import com.amazonaws.services.transcribe.model.Transcript;
-import com.amazonaws.services.transcribe.model.TranscriptionJob;
-import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
import org.apache.nifi.processor.ProcessContext;
-import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
+import org.apache.nifi.processor.Relationship;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
+import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@@ -39,93 +31,118 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.services.transcribe.TranscribeClient;
+import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
+import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
+import software.amazon.awssdk.services.transcribe.model.Transcript;
+import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;
+import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;
import java.util.Collections;
-import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class GetAwsTranscribeJobStatusTest {
- private static final String TEST_TASK_ID = "testTaskId";
- private static final String AWS_CREDENTIAL_PROVIDER_NAME = "awsCredetialProvider";
+ private static final String TEST_TASK_ID = "testJobId";
private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
private static final String REASON_OF_FAILURE = "reasonOfFailure";
private static final String CONTENT_STRING = "content";
private TestRunner runner;
@Mock
- private AmazonTranscribeClient mockTranscribeClient;
- @Mock
- private AWSCredentialsProviderService mockAwsCredentialsProvider;
+ private TranscribeClient mockTranscribeClient;
+
+ private GetAwsTranscribeJobStatus processor;
+
@Captor
 private ArgumentCaptor<GetTranscriptionJobRequest> requestCaptor;
+ private TestRunner createRunner(final GetAwsTranscribeJobStatus processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
+
@BeforeEach
public void setUp() throws InitializationException {
- when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIAL_PROVIDER_NAME);
- final GetAwsTranscribeJobStatus mockPollyFetcher = new GetAwsTranscribeJobStatus() {
+ processor = new GetAwsTranscribeJobStatus() {
@Override
- protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
+ public TranscribeClient getClient(final ProcessContext context) {
return mockTranscribeClient;
}
};
- runner = TestRunners.newTestRunner(mockPollyFetcher);
- runner.addControllerService(AWS_CREDENTIAL_PROVIDER_NAME, mockAwsCredentialsProvider);
- runner.enableControllerService(mockAwsCredentialsProvider);
- runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIAL_PROVIDER_NAME);
+ runner = createRunner(processor);
}
@Test
- public void testTranscribeTaskInProgress() {
- TranscriptionJob task = new TranscriptionJob()
- .withTranscriptionJobName(TEST_TASK_ID)
- .withTranscriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS);
- GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
- when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
- runner.run();
-
- runner.assertAllFlowFilesTransferred(REL_RUNNING);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
+ public void testTranscribeJobInProgress() {
+ final TranscriptionJob job = TranscriptionJob.builder()
+ .transcriptionJobName(TEST_TASK_ID)
+ .transcriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS)
+ .build();
+ testTranscribeJob(job, REL_RUNNING);
}
@Test
- public void testTranscribeTaskCompleted() {
- TranscriptionJob task = new TranscriptionJob()
- .withTranscriptionJobName(TEST_TASK_ID)
- .withTranscript(new Transcript().withTranscriptFileUri(OUTPUT_LOCATION_PATH))
- .withTranscriptionJobStatus(TranscriptionJobStatus.COMPLETED);
- GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
- when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
- runner.run();
-
- runner.assertAllFlowFilesTransferred(REL_SUCCESS);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
+ public void testTranscribeJobQueued() {
+ final TranscriptionJob job = TranscriptionJob.builder()
+ .transcriptionJobName(TEST_TASK_ID)
+ .transcriptionJobStatus(TranscriptionJobStatus.QUEUED)
+ .build();
+ testTranscribeJob(job, REL_RUNNING);
}
+ @Test
+ public void testTranscribeJobCompleted() {
+ final TranscriptionJob job = TranscriptionJob.builder()
+ .transcriptionJobName(TEST_TASK_ID)
+ .transcript(Transcript.builder().transcriptFileUri(OUTPUT_LOCATION_PATH).build())
+ .transcriptionJobStatus(TranscriptionJobStatus.COMPLETED)
+ .build();
+ testTranscribeJob(job, REL_SUCCESS);
+ runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
+ final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
+ assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
+ }
@Test
- public void testPollyTaskFailed() {
- TranscriptionJob task = new TranscriptionJob()
- .withTranscriptionJobName(TEST_TASK_ID)
- .withFailureReason(REASON_OF_FAILURE)
- .withTranscriptionJobStatus(TranscriptionJobStatus.FAILED);
- GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
- when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
- runner.run();
-
- runner.assertAllFlowFilesTransferred(REL_FAILURE);
+ public void testTranscribeJobFailed() {
+ final TranscriptionJob job = TranscriptionJob.builder()
+ .transcriptionJobName(TEST_TASK_ID)
+ .failureReason(REASON_OF_FAILURE)
+ .transcriptionJobStatus(TranscriptionJobStatus.FAILED)
+ .build();
+ testTranscribeJob(job, REL_FAILURE);
runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
+ final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
+ assertEquals(REASON_OF_FAILURE, flowFile.getAttribute(FAILURE_REASON_ATTRIBUTE));
+ }
+ @Test
+ public void testTranscribeJobUnrecognized() {
+ final TranscriptionJob job = TranscriptionJob.builder()
+ .transcriptionJobName(TEST_TASK_ID)
+ .failureReason(REASON_OF_FAILURE)
+ .transcriptionJobStatus(TranscriptionJobStatus.UNKNOWN_TO_SDK_VERSION)
+ .build();
+ testTranscribeJob(job, REL_FAILURE);
+ runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
+ }
+
+ private void testTranscribeJob(final TranscriptionJob job, final Relationship expectedRelationship) {
+ final GetTranscriptionJobResponse response = GetTranscriptionJobResponse.builder().transcriptionJob(job).build();
+ when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(response);
+ runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(expectedRelationship);
+ assertEquals(TEST_TASK_ID, requestCaptor.getValue().transcriptionJobName());
}
}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJobTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJobTest.java
new file mode 100644
index 0000000000..1da8247162
--- /dev/null
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJobTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nifi.processors.aws.ml.transcribe;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.MapperFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.json.JsonMapper;
+import org.apache.nifi.processor.ProcessContext;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
+import org.apache.nifi.reporting.InitializationException;
+import org.apache.nifi.util.TestRunner;
+import org.apache.nifi.util.TestRunners;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.awscore.exception.AwsServiceException;
+import software.amazon.awssdk.services.transcribe.TranscribeClient;
+import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
+import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;
+import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class StartAwsTranscribeJobTest {
+ private static final String TEST_TASK_ID = "testTaskId";
+ private TestRunner runner;
+ @Mock
+ private TranscribeClient mockTranscribeClient;
+
+ private StartAwsTranscribeJob processor;
+
+ private ObjectMapper objectMapper = JsonMapper.builder()
+ .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
+ .build();
+ @Captor
+ private ArgumentCaptor<StartTranscriptionJobRequest> requestCaptor;
+
+ private TestRunner createRunner(final StartAwsTranscribeJob processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
+
+ @BeforeEach
+ public void setUp() throws InitializationException {
+ processor = new StartAwsTranscribeJob() {
+ @Override
+ public TranscribeClient getClient(ProcessContext context) {
+ return mockTranscribeClient;
+ }
+ };
+ runner = createRunner(processor);
+ }
+
+ @Test
+ public void testSuccessfulFlowfileContent() throws JsonProcessingException {
+ final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
+ .transcriptionJobName("Job")
+ .build();
+ final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder()
+ .transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build())
+ .build();
+ when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartTranscriptionJobResponse parsedResponse = deserialize(responseData);
+
+ assertEquals("Job", requestCaptor.getValue().transcriptionJobName());
+ assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName());
+ }
+
+ @Test
+ public void testSuccessfulAttribute() throws JsonProcessingException {
+ final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
+ .transcriptionJobName("Job")
+ .build();
+ final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder()
+ .transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build())
+ .build();
+ when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.setProperty(StartAwsTranscribeJob.JSON_PAYLOAD, "${json.payload}");
+ final Map<String, String> attributes = new HashMap<>();
+ attributes.put("json.payload", requestJson);
+ runner.enqueue("", attributes);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartTranscriptionJobResponse parsedResponse = deserialize(responseData);
+
+ assertEquals("Job", requestCaptor.getValue().transcriptionJobName());
+ assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName());
+ }
+
+ @Test
+ public void testInvalidJson() {
+ final String requestJson = "invalid";
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testServiceFailure() throws JsonProcessingException {
+ final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
+ .transcriptionJobName("Job")
+ .build();
+ when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ private StartTranscriptionJobResponse deserialize(final String responseData) throws JsonProcessingException {
+ return objectMapper.readValue(responseData, StartTranscriptionJobResponse.serializableBuilderClass()).build();
+ }
+
+ private String serialize(final StartTranscriptionJobRequest request) throws JsonProcessingException {
+ return objectMapper.writeValueAsString(request.toBuilder());
+ }
+}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java
index 7c868c70ce..dd3b701348 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java
@@ -17,19 +17,11 @@
package org.apache.nifi.processors.aws.ml.translate;
-import com.amazonaws.ClientConfiguration;
-import com.amazonaws.auth.AWSCredentialsProvider;
-import com.amazonaws.client.builder.AwsClientBuilder;
-import com.amazonaws.regions.Region;
-import com.amazonaws.services.translate.AmazonTranslateClient;
-import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
-import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
-import com.amazonaws.services.translate.model.JobStatus;
-import com.amazonaws.services.translate.model.OutputDataConfig;
-import com.amazonaws.services.translate.model.TextTranslationJobProperties;
import org.apache.nifi.processor.ProcessContext;
-import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
+import org.apache.nifi.processor.Relationship;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
+import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@@ -39,91 +31,132 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.services.translate.TranslateClient;
+import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
+import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
+import software.amazon.awssdk.services.translate.model.JobStatus;
+import software.amazon.awssdk.services.translate.model.OutputDataConfig;
+import software.amazon.awssdk.services.translate.model.TextTranslationJobProperties;
+import java.time.Instant;
import java.util.Collections;
-import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
-import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class GetAwsTranslateJobStatusTest {
- private static final String TEST_TASK_ID = "testTaskId";
+ private static final String TEST_TASK_ID = "testJobId";
+ private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
+ private static final String REASON_OF_FAILURE = "reasonOfFailure";
private static final String CONTENT_STRING = "content";
- private static final String AWS_CREDENTIALS_PROVIDER_NAME = "awsCredetialProvider";
- private static final String OUTPUT_LOCATION_PATH = "outputLocation";
private TestRunner runner;
@Mock
- private AmazonTranslateClient mockTranslateClient;
- @Mock
- private AWSCredentialsProviderService mockAwsCredentialsProvider;
+ private TranslateClient mockTranslateClient;
+
+ private GetAwsTranslateJobStatus processor;
+
@Captor
private ArgumentCaptor requestCaptor;
+ private TestRunner createRunner(final GetAwsTranslateJobStatus processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
@BeforeEach
public void setUp() throws InitializationException {
- when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIALS_PROVIDER_NAME);
- final GetAwsTranslateJobStatus mockPollyFetcher = new GetAwsTranslateJobStatus() {
+ processor = new GetAwsTranslateJobStatus() {
@Override
- protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
- final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
+ public TranslateClient getClient(final ProcessContext context) {
return mockTranslateClient;
}
};
- runner = TestRunners.newTestRunner(mockPollyFetcher);
- runner.addControllerService(AWS_CREDENTIALS_PROVIDER_NAME, mockAwsCredentialsProvider);
- runner.enableControllerService(mockAwsCredentialsProvider);
- runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIALS_PROVIDER_NAME);
+ runner = createRunner(processor);
}
@Test
- public void testTranscribeTaskInProgress() {
- TextTranslationJobProperties task = new TextTranslationJobProperties()
- .withJobId(TEST_TASK_ID)
- .withJobStatus(JobStatus.IN_PROGRESS);
- DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
- when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
- runner.run();
-
- runner.assertAllFlowFilesTransferred(REL_RUNNING);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
+ public void testTranslateJobInProgress() {
+ final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+ .jobId(TEST_TASK_ID)
+ .jobStatus(JobStatus.IN_PROGRESS)
+ .build();
+ testTranslateJob(job, REL_RUNNING);
}
@Test
- public void testTranscribeTaskCompleted() {
- TextTranslationJobProperties task = new TextTranslationJobProperties()
- .withJobId(TEST_TASK_ID)
- .withOutputDataConfig(new OutputDataConfig().withS3Uri(OUTPUT_LOCATION_PATH))
- .withJobStatus(JobStatus.COMPLETED);
- DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
- when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
- runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
- runner.run();
+ public void testTranslateSubmitted() {
+ final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+ .jobId(TEST_TASK_ID)
+ .jobStatus(JobStatus.SUBMITTED)
+ .build();
+ testTranslateJob(job, REL_RUNNING);
+ }
- runner.assertAllFlowFilesTransferred(REL_SUCCESS);
+ @Test
+ public void testTranslateStopRequested() {
+ final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+ .jobId(TEST_TASK_ID)
+ .jobStatus(JobStatus.STOP_REQUESTED)
+ .build();
+ testTranslateJob(job, REL_RUNNING);
+ }
+
+    @Test
+    public void testTranslateJobCompleted() {
+        // Completed jobs must route to success and surface the S3 output location as an attribute.
+        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+                .jobId(TEST_TASK_ID)
+                .outputDataConfig(OutputDataConfig.builder().s3Uri(OUTPUT_LOCATION_PATH).build())
+                .submittedTime(Instant.now())
+                .jobStatus(JobStatus.COMPLETED)
+                .build();
+        testTranslateJob(job, REL_SUCCESS);
         runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
-        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
+        final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
+        assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
     }
@Test
- public void testTranscribeTaskFailed() {
- TextTranslationJobProperties task = new TextTranslationJobProperties()
- .withJobId(TEST_TASK_ID)
- .withJobStatus(JobStatus.FAILED);
- DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
- when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
+    public void testTranslateJobFailed() {
+        // FAILED status must route the flow file to the failure relationship.
+        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+                .jobId(TEST_TASK_ID)
+                .jobStatus(JobStatus.FAILED)
+                .build();
+        testTranslateJob(job, REL_FAILURE);
+    }
+
+    @Test
+    public void testTranslateJobStopped() {
+        // STOPPED status must route the flow file to the failure relationship.
+        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+                .jobId(TEST_TASK_ID)
+                .jobStatus(JobStatus.STOPPED)
+                .build();
+        testTranslateJob(job, REL_FAILURE);
+    }
+
+    @Test
+    public void testTranslateJobUnrecognized() {
+        // A status unknown to this SDK version must be treated as a failure.
+        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+                .jobId(TEST_TASK_ID)
+                .jobStatus(JobStatus.UNKNOWN_TO_SDK_VERSION)
+                .build();
+        testTranslateJob(job, REL_FAILURE);
+    }
+
+ private void testTranslateJob(final TextTranslationJobProperties job, final Relationship expectedRelationship) {
+ final DescribeTextTranslationJobResponse response = DescribeTextTranslationJobResponse.builder().textTranslationJobProperties(job).build();
+ when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(response);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
- runner.assertAllFlowFilesTransferred(REL_FAILURE);
- assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
+ runner.assertAllFlowFilesTransferred(expectedRelationship);
+ assertEquals(TEST_TASK_ID, requestCaptor.getValue().jobId());
}
}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJobTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJobTest.java
new file mode 100644
index 0000000000..d8e1568c1d
--- /dev/null
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJobTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.nifi.processors.aws.ml.translate;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.MapperFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.json.JsonMapper;
+import org.apache.nifi.processor.ProcessContext;
+import org.apache.nifi.processors.aws.testutil.AuthUtils;
+import org.apache.nifi.reporting.InitializationException;
+import org.apache.nifi.util.TestRunner;
+import org.apache.nifi.util.TestRunners;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import software.amazon.awssdk.awscore.exception.AwsServiceException;
+import software.amazon.awssdk.services.translate.TranslateClient;
+import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
+import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
+import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class StartAwsTranslateJobTest {
+ private static final String TEST_TASK_ID = "testTaskId";
+ private TestRunner runner;
+ @Mock
+ private TranslateClient mockTranslateClient;
+
+ private StartAwsTranslateJob processor;
+
+    // Shared, never-reassigned mapper: declare final per convention for cached Jackson mappers.
+    private final ObjectMapper objectMapper = JsonMapper.builder()
+            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
+            .build();
+ @Captor
+ private ArgumentCaptor requestCaptor;
+
+ private TestRunner createRunner(final StartAwsTranslateJob processor) {
+ final TestRunner runner = TestRunners.newTestRunner(processor);
+ AuthUtils.enableAccessKey(runner, "abcd", "defg");
+ return runner;
+ }
+
+    @BeforeEach
+    public void setUp() throws InitializationException {
+        // Override getClient so the processor talks to the mock instead of AWS.
+        processor = new StartAwsTranslateJob() {
+            @Override
+            public TranslateClient getClient(final ProcessContext context) {
+                return mockTranslateClient;
+            }
+        };
+        runner = createRunner(processor);
+    }
+
+ @Test
+ public void testSuccessfulFlowfileContent() throws JsonProcessingException {
+ final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
+ .terminologyNames("Name")
+ .build();
+ final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
+ .jobId(TEST_TASK_ID)
+ .build();
+ when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertTransferCount(REL_SUCCESS, 1);
+ runner.assertTransferCount(REL_ORIGINAL, 1);
+ final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+ final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);
+
+ assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
+ assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+ }
+
+    @Test
+    public void testSuccessfulAttribute() throws JsonProcessingException {
+        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
+                .terminologyNames("Name")
+                .build();
+        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
+                .jobId(TEST_TASK_ID)
+                .build();
+        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);
+
+        final String requestJson = serialize(request);
+        runner.setProperty(StartAwsTranslateJob.JSON_PAYLOAD, "${json.payload}");
+        final Map<String, String> attributes = new HashMap<>();
+        attributes.put("json.payload", requestJson);
+        runner.enqueue("", attributes);
+        runner.run();
+
+        runner.assertTransferCount(REL_SUCCESS, 1);
+        runner.assertTransferCount(REL_ORIGINAL, 1);
+        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);
+
+        assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
+        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+    }
+
+ @Test
+ public void testInvalidJson() {
+ final String requestJson = "invalid";
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ @Test
+ public void testServiceFailure() throws JsonProcessingException {
+ final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
+ .terminologyNames("Name")
+ .build();
+ when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
+
+ final String requestJson = serialize(request);
+ runner.enqueue(requestJson);
+ runner.run();
+
+ runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+ }
+
+ private StartTextTranslationJobResponse deserialize(final String responseData) throws JsonProcessingException {
+ return objectMapper.readValue(responseData, StartTextTranslationJobResponse.serializableBuilderClass()).build();
+ }
+
+ private String serialize(final StartTextTranslationJobRequest request) throws JsonProcessingException {
+ return objectMapper.writeValueAsString(request.toBuilder());
+ }
+}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java
index 4e93560297..c0c7981b79 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java
@@ -66,7 +66,6 @@ public abstract class AbstractSQSIT {
queueUrl = response.queueUrl();
}
-
@AfterAll
public static void shutdown() {
client.close();