diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml index c40c81ae98..a2ab512cd4 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/pom.xml @@ -65,6 +65,22 @@ software.amazon.awssdk firehose + + software.amazon.awssdk + polly + + + software.amazon.awssdk + textract + + + software.amazon.awssdk + transcribe + + + software.amazon.awssdk + translate + software.amazon.kinesis amazon-kinesis-client @@ -130,6 +146,10 @@ com.github.ben-manes.caffeine caffeine + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + org.bouncycastle diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStarter.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStarter.java similarity index 66% rename from nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStarter.java rename to nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStarter.java index 27f0664370..f1c0ab1fb3 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStarter.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStarter.java @@ -17,14 +17,6 @@ package org.apache.nifi.processors.aws.ml; -import com.amazonaws.AmazonWebServiceClient; -import com.amazonaws.AmazonWebServiceRequest; -import com.amazonaws.AmazonWebServiceResult; -import com.amazonaws.ClientConfiguration; -import 
com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.regions.Regions; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.MapperFeature; import com.fasterxml.jackson.databind.ObjectMapper; @@ -33,12 +25,18 @@ import org.apache.commons.io.IOUtils; import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.migration.PropertyConfiguration; import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.util.StandardValidators; -import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor; +import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor; +import software.amazon.awssdk.awscore.AwsRequest; +import software.amazon.awssdk.awscore.AwsResponse; +import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder; +import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder; +import software.amazon.awssdk.core.SdkClient; import java.io.IOException; import java.io.InputStream; @@ -46,10 +44,15 @@ import java.util.List; import java.util.Set; import static org.apache.nifi.flowfile.attributes.CoreAttributes.MIME_TYPE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID; -public abstract class AwsMachineLearningJobStarter - extends AbstractAWSCredentialsProviderProcessor { +public abstract class AbstractAwsMachineLearningJobStarter< + Q extends AwsRequest, + B extends AwsRequest.Builder, + R extends AwsResponse, + T extends SdkClient, + U extends 
AwsSyncClientBuilder & AwsClientBuilder> + extends AbstractAwsSyncProcessor { public static final PropertyDescriptor JSON_PAYLOAD = new PropertyDescriptor.Builder() .name("json-payload") .displayName("JSON Payload") @@ -62,18 +65,17 @@ public abstract class AwsMachineLearningJobStarter PROPERTIES = List.of( MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE, REGION, @@ -84,10 +86,9 @@ public abstract class AwsMachineLearningJobStarter relationships = Set.of(REL_ORIGINAL, - REL_SUCCESS, - REL_FAILURE); + private static final Set relationships = Set.of(REL_ORIGINAL, REL_SUCCESS, REL_FAILURE); @Override public Set getRelationships() { @@ -105,14 +106,14 @@ public abstract class AwsMachineLearningJobStarter MAPPER.writeValue(out, response)); + childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response.toBuilder())); return childFlowFile; } @@ -156,7 +152,7 @@ public abstract class AwsMachineLearningJobStarter getAwsRequestClass(ProcessContext context, FlowFile flowFile); + abstract protected Class getAwsRequestBuilderClass(ProcessContext context, FlowFile flowFile); - abstract protected String getAwsTaskId(ProcessContext context, RESPONSE response, FlowFile flowFile); + abstract protected String getAwsTaskId(ProcessContext context, R response, FlowFile flowFile); } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStatusProcessor.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStatusProcessor.java similarity index 69% rename from nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStatusProcessor.java rename to nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStatusProcessor.java index 5a5f8f02e2..8cc555b46a 100644 
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/AwsMachineLearningJobStatusProcessor.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-abstract-processors/src/main/java/org/apache/nifi/processors/aws/ml/AbstractAwsMachineLearningJobStatusProcessor.java @@ -17,33 +17,30 @@ package org.apache.nifi.processors.aws.ml; -import com.amazonaws.AmazonWebServiceClient; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.ResponseMetadata; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.http.SdkHttpMetadata; -import com.amazonaws.regions.Region; -import com.amazonaws.regions.Regions; import com.fasterxml.jackson.databind.MapperFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; -import com.fasterxml.jackson.databind.module.SimpleModule; import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.migration.PropertyConfiguration; import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.util.StandardValidators; -import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor; +import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor; +import software.amazon.awssdk.awscore.AwsResponse; +import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder; +import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder; +import software.amazon.awssdk.core.SdkClient; import java.util.List; import java.util.Set; import static org.apache.nifi.expression.ExpressionLanguageScope.FLOWFILE_ATTRIBUTES; -public abstract class AwsMachineLearningJobStatusProcessor - extends AbstractAWSCredentialsProviderProcessor { +public abstract class 
AbstractAwsMachineLearningJobStatusProcessor< + T extends SdkClient, + U extends AwsSyncClientBuilder & AwsClientBuilder> + extends AbstractAwsSyncProcessor { public static final String AWS_TASK_OUTPUT_LOCATION = "outputLocation"; public static final PropertyDescriptor MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE = new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE) @@ -81,13 +78,6 @@ public abstract class AwsMachineLearningJobStatusProcessor PROPERTIES = List.of( TASK_ID, @@ -99,18 +89,14 @@ public abstract class AwsMachineLearningJobStatusProcessor getRelationships() { return relationships; @@ -129,13 +115,7 @@ public abstract class AwsMachineLearningJobStatusProcessor MAPPER.writeValue(out, response)); + protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final AwsResponse response) { + return session.write(flowFile, out -> MAPPER.writeValue(out, response.toBuilder())); } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml index 627b4f8b7f..c0748af598 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/pom.xml @@ -125,22 +125,6 @@ org.apache.nifi nifi-schema-registry-service-api - - com.amazonaws - aws-java-sdk-translate - - - com.amazonaws - aws-java-sdk-polly - - - com.amazonaws - aws-java-sdk-transcribe - - - com.amazonaws - aws-java-sdk-textract - org.bouncycastle bcprov-jdk18on diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java index 73d6bfa148..e2325d8fc2 100644 --- 
a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyJobStatus.java @@ -17,15 +17,6 @@ package org.apache.nifi.processors.aws.ml.polly; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.polly.AmazonPollyClient; -import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest; -import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult; -import com.amazonaws.services.polly.model.TaskStatus; -import com.amazonaws.services.textract.model.ThrottlingException; import org.apache.nifi.annotation.behavior.WritesAttribute; import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.documentation.CapabilityDescription; @@ -34,9 +25,17 @@ import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessSession; +import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.exception.ProcessException; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor; +import software.amazon.awssdk.services.polly.PollyClient; +import software.amazon.awssdk.services.polly.PollyClientBuilder; +import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest; +import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse; +import software.amazon.awssdk.services.polly.model.TaskStatus; +import java.util.HashSet; +import java.util.Set; import java.util.regex.Matcher; import 
java.util.regex.Pattern; @@ -45,10 +44,10 @@ import java.util.regex.Pattern; @SeeAlso({StartAwsPollyJob.class}) @WritesAttributes({ @WritesAttribute(attribute = "PollyS3OutputBucket", description = "The bucket name where polly output will be located."), - @WritesAttribute(attribute = "PollyS3OutputKey", description = "Object key of polly output."), + @WritesAttribute(attribute = "filename", description = "Object key of polly output."), @WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.") }) -public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor { +public class GetAwsPollyJobStatus extends AbstractAwsMachineLearningJobStatusProcessor { private static final String BUCKET = "bucket"; private static final String KEY = "key"; private static final Pattern S3_PATH = Pattern.compile("https://s3.*amazonaws.com/(?<" + BUCKET + ">[^/]+)/(?<" + KEY + ">.*)"); @@ -56,65 +55,66 @@ public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor getRelationships() { + final Set parentRelationships = new HashSet<>(super.getRelationships()); + parentRelationships.remove(REL_THROTTLED); + return Set.copyOf(parentRelationships); } @Override - public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { + protected PollyClientBuilder createClientBuilder(final ProcessContext context) { + return PollyClient.builder(); + } + + @Override + public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { FlowFile flowFile = session.get(); if (flowFile == null) { return; } - GetSpeechSynthesisTaskResult speechSynthesisTask; + final GetSpeechSynthesisTaskResponse speechSynthesisTask; try { speechSynthesisTask = getSynthesisTask(context, flowFile); - } catch (ThrottlingException e) { - getLogger().info("Request Rate Limit exceeded", e); - session.transfer(flowFile, REL_THROTTLED); - return; - } catch (Exception e) { + } 
catch (final Exception e) { getLogger().warn("Failed to get Polly Job status", e); session.transfer(flowFile, REL_FAILURE); return; } - TaskStatus taskStatus = TaskStatus.fromValue(speechSynthesisTask.getSynthesisTask().getTaskStatus()); + final TaskStatus taskStatus = speechSynthesisTask.synthesisTask().taskStatus(); - if (taskStatus == TaskStatus.InProgress || taskStatus == TaskStatus.Scheduled) { - session.penalize(flowFile); + if (taskStatus == TaskStatus.IN_PROGRESS || taskStatus == TaskStatus.SCHEDULED) { + flowFile = session.penalize(flowFile); session.transfer(flowFile, REL_RUNNING); - } else if (taskStatus == TaskStatus.Completed) { - String outputUri = speechSynthesisTask.getSynthesisTask().getOutputUri(); + } else if (taskStatus == TaskStatus.COMPLETED) { + final String outputUri = speechSynthesisTask.synthesisTask().outputUri(); - Matcher matcher = S3_PATH.matcher(outputUri); + final Matcher matcher = S3_PATH.matcher(outputUri); if (matcher.find()) { - session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET)); - session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY)); + flowFile = session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET)); + flowFile = session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY)); } FlowFile childFlowFile = session.create(flowFile); - writeToFlowFile(session, childFlowFile, speechSynthesisTask); + childFlowFile = writeToFlowFile(session, childFlowFile, speechSynthesisTask); childFlowFile = session.putAttribute(childFlowFile, AWS_TASK_OUTPUT_LOCATION, outputUri); session.transfer(flowFile, REL_ORIGINAL); session.transfer(childFlowFile, REL_SUCCESS); getLogger().info("Amazon Polly Task Completed {}", flowFile); - } else if (taskStatus == TaskStatus.Failed) { - final String failureReason = speechSynthesisTask.getSynthesisTask().getTaskStatusReason(); + } else if (taskStatus == TaskStatus.FAILED) { + final String failureReason = speechSynthesisTask.synthesisTask().taskStatusReason(); flowFile = 
session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason); session.transfer(flowFile, REL_FAILURE); getLogger().error("Amazon Polly Task Failed {} Reason [{}]", flowFile, failureReason); + } else if (taskStatus == TaskStatus.UNKNOWN_TO_SDK_VERSION) { + flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status"); + session.transfer(flowFile, REL_FAILURE); + getLogger().error("Amazon Polly Task Failed {} Reason [Unrecognized job status]", flowFile); } } - private GetSpeechSynthesisTaskResult getSynthesisTask(ProcessContext context, FlowFile flowFile) { - String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); - GetSpeechSynthesisTaskRequest request = new GetSpeechSynthesisTaskRequest().withTaskId(taskId); + private GetSpeechSynthesisTaskResponse getSynthesisTask(final ProcessContext context, final FlowFile flowFile) { + final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); + final GetSpeechSynthesisTaskRequest request = GetSpeechSynthesisTaskRequest.builder().taskId(taskId).build(); return getClient(context).getSpeechSynthesisTask(request); } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java index 0ab52ac597..8052b92558 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJob.java @@ -17,49 +17,45 @@ package org.apache.nifi.processors.aws.ml.polly; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; 
-import com.amazonaws.regions.Region; -import com.amazonaws.services.polly.AmazonPollyClient; -import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskRequest; -import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskResult; +import org.apache.nifi.annotation.behavior.WritesAttribute; +import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter; +import software.amazon.awssdk.services.polly.PollyClient; +import software.amazon.awssdk.services.polly.PollyClientBuilder; +import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest; +import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse; @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Polly"}) @CapabilityDescription("Trigger a AWS Polly job. 
It should be followed by GetAwsPollyJobStatus processor in order to monitor job status.") +@WritesAttributes({ + @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsPollyJobStatus") +}) @SeeAlso({GetAwsPollyJobStatus.class}) -public class StartAwsPollyJob extends AwsMachineLearningJobStarter { +public class StartAwsPollyJob extends AbstractAwsMachineLearningJobStarter< + StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskRequest.Builder, StartSpeechSynthesisTaskResponse, PollyClient, PollyClientBuilder> { @Override - protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { - - return (AmazonPollyClient) AmazonPollyClient.builder() - .withRegion(context.getProperty(REGION).getValue()) - .withCredentials(credentialsProvider) - .withClientConfiguration(config) - .withEndpointConfiguration(endpointConfiguration) - .build(); + protected PollyClientBuilder createClientBuilder(final ProcessContext context) { + return PollyClient.builder(); } @Override - protected StartSpeechSynthesisTaskResult sendRequest(StartSpeechSynthesisTaskRequest request, ProcessContext context, FlowFile flowFile) { + protected StartSpeechSynthesisTaskResponse sendRequest(final StartSpeechSynthesisTaskRequest request, final ProcessContext context, final FlowFile flowFile) { return getClient(context).startSpeechSynthesisTask(request); } @Override - protected Class getAwsRequestClass(ProcessContext context, FlowFile flowFile) { - return StartSpeechSynthesisTaskRequest.class; + protected Class getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) { + return StartSpeechSynthesisTaskRequest.serializableBuilderClass(); } @Override - protected String getAwsTaskId(ProcessContext context, StartSpeechSynthesisTaskResult 
startSpeechSynthesisTaskResult, FlowFile flowFile) { - return startSpeechSynthesisTaskResult.getSynthesisTask().getTaskId(); + protected String getAwsTaskId(final ProcessContext context, final StartSpeechSynthesisTaskResponse startSpeechSynthesisTaskResponse, final FlowFile flowFile) { + return startSpeechSynthesisTaskResponse.synthesisTask().taskId(); } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java index ce48142f0c..2687735217 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatus.java @@ -17,48 +17,61 @@ package org.apache.nifi.processors.aws.ml.textract; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.textract.AmazonTextractClient; -import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest; -import com.amazonaws.services.textract.model.GetDocumentTextDetectionRequest; -import com.amazonaws.services.textract.model.GetExpenseAnalysisRequest; -import com.amazonaws.services.textract.model.JobStatus; -import com.amazonaws.services.textract.model.ThrottlingException; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.components.ValidationContext; +import org.apache.nifi.components.ValidationResult; +import 
org.apache.nifi.components.Validator; import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.exception.ProcessException; -import org.apache.nifi.processor.util.StandardValidators; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor; +import software.amazon.awssdk.services.textract.TextractClient; +import software.amazon.awssdk.services.textract.TextractClientBuilder; +import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest; +import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest; +import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest; +import software.amazon.awssdk.services.textract.model.JobStatus; +import software.amazon.awssdk.services.textract.model.TextractResponse; +import software.amazon.awssdk.services.textract.model.ThrottlingException; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS; +import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE; @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"}) @CapabilityDescription("Retrieves the current status of an AWS Textract job.") @SeeAlso({StartAwsTextractJob.class}) -public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcessor { +public class GetAwsTextractJobStatus extends AbstractAwsMachineLearningJobStatusProcessor { + + public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() { + @Override + public ValidationResult validate(final String subject, final String value, final ValidationContext 
context) { + if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) { + return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build(); + } else if (TextractType.TEXTRACT_TYPES.contains(value)) { + return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build(); + } else { + return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build(); + } + } + }; + public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder() .name("textract-type") .displayName("Textract Type") .required(true) .description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"") - .allowableValues(TextractType.TEXTRACT_TYPES) .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) - .defaultValue(DOCUMENT_ANALYSIS.getType()) - .addValidator(StandardValidators.NON_EMPTY_VALIDATOR) + .defaultValue(String.format("${%s}", TEXTRACT_TYPE_ATTRIBUTE)) + .addValidator(TEXTRACT_TYPE_VALIDATOR) .build(); private static final List TEXTRACT_PROPERTIES = Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList())); @@ -69,30 +82,24 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso } @Override - protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { - return (AmazonTextractClient) AmazonTextractClient.builder() - .withRegion(context.getProperty(REGION).getValue()) - .withClientConfiguration(config) - .withEndpointConfiguration(endpointConfiguration) - 
.withCredentials(credentialsProvider) - .build(); + protected TextractClientBuilder createClientBuilder(final ProcessContext context) { + return TextractClient.builder(); } @Override - public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { + public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { FlowFile flowFile = session.get(); if (flowFile == null) { return; } - String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue(); + final String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue(); - String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); + final String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); try { - JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId); + final JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId); if (JobStatus.SUCCEEDED == jobStatus) { - Object task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId); - writeToFlowFile(session, flowFile, task); + final TextractResponse task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId); + flowFile = writeToFlowFile(session, flowFile, task); session.transfer(flowFile, REL_SUCCESS); } else if (JobStatus.IN_PROGRESS == jobStatus) { session.transfer(flowFile, REL_RUNNING); @@ -101,29 +108,31 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso } else if (JobStatus.FAILED == jobStatus) { session.transfer(flowFile, REL_FAILURE); getLogger().error("Amazon Textract Task [{}] Failed", awsTaskId); + } else { + throw new IllegalStateException("Unrecognized job status"); } - } catch (ThrottlingException e) { + } catch (final 
ThrottlingException e) { getLogger().info("Request Rate Limit exceeded", e); session.transfer(flowFile, REL_THROTTLED); - } catch (Exception e) { + } catch (final Exception e) { getLogger().warn("Failed to get Textract Job status", e); session.transfer(flowFile, REL_FAILURE); } } - private Object getTask(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) { + private TextractResponse getTask(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) { return switch (typeOfTextract) { - case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId)); - case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId)); - case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId)); + case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build()); + case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build()); + case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build()); }; } - private JobStatus getTaskStatus(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) { + private JobStatus getTaskStatus(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) { return switch (typeOfTextract) { - case DOCUMENT_ANALYSIS -> JobStatus.fromValue(client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId)).getJobStatus()); - case DOCUMENT_TEXT_DETECTION -> JobStatus.fromValue(client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId)).getJobStatus()); - case EXPENSE_ANALYSIS -> JobStatus.fromValue(client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId)).getJobStatus()); + case 
DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus(); + case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build()).jobStatus(); + case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus(); }; } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java index 70ff2391a8..301aadc880 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJob.java @@ -17,31 +17,26 @@ package org.apache.nifi.processors.aws.ml.textract; -import com.amazonaws.AmazonWebServiceRequest; -import com.amazonaws.AmazonWebServiceResult; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.textract.AmazonTextractClient; -import com.amazonaws.services.textract.model.StartDocumentAnalysisRequest; -import com.amazonaws.services.textract.model.StartDocumentAnalysisResult; -import com.amazonaws.services.textract.model.StartDocumentTextDetectionRequest; -import com.amazonaws.services.textract.model.StartDocumentTextDetectionResult; -import com.amazonaws.services.textract.model.StartExpenseAnalysisRequest; -import com.amazonaws.services.textract.model.StartExpenseAnalysisResult; +import org.apache.nifi.annotation.behavior.WritesAttribute; +import 
org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.components.PropertyDescriptor; -import org.apache.nifi.components.ValidationContext; -import org.apache.nifi.components.ValidationResult; -import org.apache.nifi.components.Validator; -import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessSession; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter; +import software.amazon.awssdk.services.textract.TextractClient; +import software.amazon.awssdk.services.textract.TextractClientBuilder; +import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest; +import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse; +import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest; +import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse; +import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest; +import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse; +import software.amazon.awssdk.services.textract.model.TextractRequest; +import software.amazon.awssdk.services.textract.model.TextractResponse; import java.util.Collections; import java.util.List; @@ -52,28 +47,23 @@ import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_A @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"}) @CapabilityDescription("Trigger a AWS Textract job. 
It should be followed by GetAwsTextractJobStatus processor in order to monitor job status.") +@WritesAttributes({ + @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTextractJobStatus"), + @WritesAttribute(attribute = "awsTextractType", description = "The selected Textract type, which can be used in GetAwsTextractJobStatus") +}) @SeeAlso({GetAwsTextractJobStatus.class}) -public class StartAwsTextractJob extends AwsMachineLearningJobStarter { - public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() { - @Override - public ValidationResult validate(final String subject, final String value, final ValidationContext context) { - if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) { - return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build(); - } else if (TextractType.TEXTRACT_TYPES.contains(value)) { - return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build(); - } else { - return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build(); - } - } - }; +public class StartAwsTextractJob extends AbstractAwsMachineLearningJobStarter< + TextractRequest, TextractRequest.Builder, TextractResponse, TextractClient, TextractClientBuilder> { + + public static final String TEXTRACT_TYPE_ATTRIBUTE = "awsTextractType"; + public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder() .name("textract-type") .displayName("Textract Type") .required(true) .description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"") - .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) - .defaultValue(DOCUMENT_ANALYSIS.type) - 
.addValidator(TEXTRACT_TYPE_VALIDATOR) + .allowableValues(TextractType.TEXTRACT_TYPES) + .defaultValue(DOCUMENT_ANALYSIS.getType()) .build(); private static final List TEXTRACT_PROPERTIES = Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList())); @@ -84,24 +74,13 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter getClient(context).startDocumentAnalysis((StartDocumentAnalysisRequest) request); case DOCUMENT_TEXT_DETECTION -> getClient(context).startDocumentTextDetection((StartDocumentTextDetectionRequest) request); @@ -110,22 +89,28 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter getAwsRequestClass(ProcessContext context, FlowFile flowFile) { - final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue()); + protected Class getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) { + final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue()); return switch (typeOfTextract) { - case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.class; - case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.class; - case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.class; + case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.serializableBuilderClass(); + case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.serializableBuilderClass(); + case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.serializableBuilderClass(); }; } @Override - protected String getAwsTaskId(ProcessContext context, AmazonWebServiceResult amazonWebServiceResult, FlowFile flowFile) { - final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue()); + protected String getAwsTaskId(final ProcessContext context, 
final TextractResponse textractResponse, final FlowFile flowFile) { + final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue()); return switch (textractType) { - case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResult) amazonWebServiceResult).getJobId(); - case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResult) amazonWebServiceResult).getJobId(); - case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResult) amazonWebServiceResult).getJobId(); + case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResponse) textractResponse).jobId(); + case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResponse) textractResponse).jobId(); + case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResponse) textractResponse).jobId(); }; } + + @Override + protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, FlowFile flowFile, final TextractResponse response) { + flowFile = super.postProcessFlowFile(context, session, flowFile, response); + return session.putAttribute(flowFile, TEXTRACT_TYPE_ATTRIBUTE, context.getProperty(TEXTRACT_TYPE).getValue()); + } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java index ff48af3843..9ad549b75e 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatus.java @@ -17,15 +17,6 @@ package org.apache.nifi.processors.aws.ml.transcribe; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import 
com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.textract.model.ThrottlingException; -import com.amazonaws.services.transcribe.AmazonTranscribeClient; -import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest; -import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult; -import com.amazonaws.services.transcribe.model.TranscriptionJobStatus; import org.apache.nifi.annotation.behavior.WritesAttribute; import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.documentation.CapabilityDescription; @@ -35,7 +26,13 @@ import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.exception.ProcessException; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor; +import software.amazon.awssdk.services.transcribe.TranscribeClient; +import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder; +import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest; +import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse; +import software.amazon.awssdk.services.transcribe.model.LimitExceededException; +import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus; @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"}) @CapabilityDescription("Retrieves the current status of an AWS Transcribe job.") @@ -43,55 +40,50 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; @WritesAttributes({ @WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.") }) -public class GetAwsTranscribeJobStatus extends AwsMachineLearningJobStatusProcessor { +public class 
GetAwsTranscribeJobStatus extends AbstractAwsMachineLearningJobStatusProcessor { @Override - protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { - return (AmazonTranscribeClient) AmazonTranscribeClient.builder() - .withRegion(context.getProperty(REGION).getValue()) - .withCredentials(credentialsProvider) - .withEndpointConfiguration(endpointConfiguration) - .withClientConfiguration(config) - .build(); + protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) { + return TranscribeClient.builder(); } @Override - public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { + public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { FlowFile flowFile = session.get(); if (flowFile == null) { return; } try { - GetTranscriptionJobResult job = getJob(context, flowFile); - TranscriptionJobStatus jobStatus = TranscriptionJobStatus.fromValue(job.getTranscriptionJob().getTranscriptionJobStatus()); + final GetTranscriptionJobResponse job = getJob(context, flowFile); + final TranscriptionJobStatus status = job.transcriptionJob().transcriptionJobStatus(); - if (TranscriptionJobStatus.COMPLETED == jobStatus) { - writeToFlowFile(session, flowFile, job); - session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.getTranscriptionJob().getTranscript().getTranscriptFileUri()); + if (TranscriptionJobStatus.COMPLETED == status) { + flowFile = writeToFlowFile(session, flowFile, job); + flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.transcriptionJob().transcript().transcriptFileUri()); session.transfer(flowFile, REL_SUCCESS); - } else if (TranscriptionJobStatus.IN_PROGRESS == jobStatus) { + } else if (TranscriptionJobStatus.IN_PROGRESS == status || 
TranscriptionJobStatus.QUEUED == status) { session.transfer(flowFile, REL_RUNNING); - } else if (TranscriptionJobStatus.FAILED == jobStatus) { - final String failureReason = job.getTranscriptionJob().getFailureReason(); - session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason); + } else if (TranscriptionJobStatus.FAILED == status) { + final String failureReason = job.transcriptionJob().failureReason(); + flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason); session.transfer(flowFile, REL_FAILURE); getLogger().error("Transcribe Task Failed {} Reason [{}]", flowFile, failureReason); + } else { + flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status"); + throw new IllegalStateException("Unrecognized job status"); } - } catch (ThrottlingException e) { + } catch (final LimitExceededException e) { getLogger().info("Request Rate Limit exceeded", e); session.transfer(flowFile, REL_THROTTLED); - return; - } catch (Exception e) { + } catch (final Exception e) { getLogger().warn("Failed to get Transcribe Job status", e); session.transfer(flowFile, REL_FAILURE); - return; } } - private GetTranscriptionJobResult getJob(ProcessContext context, FlowFile flowFile) { - String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); - GetTranscriptionJobRequest request = new GetTranscriptionJobRequest().withTranscriptionJobName(taskId); + private GetTranscriptionJobResponse getJob(final ProcessContext context, final FlowFile flowFile) { + final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); + final GetTranscriptionJobRequest request = GetTranscriptionJobRequest.builder().transcriptionJobName(taskId).build(); return getClient(context).getTranscriptionJob(request); } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java 
b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java index a91192e5c6..5af6423224 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJob.java @@ -17,48 +17,45 @@ package org.apache.nifi.processors.aws.ml.transcribe; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.transcribe.AmazonTranscribeClient; -import com.amazonaws.services.transcribe.model.StartTranscriptionJobRequest; -import com.amazonaws.services.transcribe.model.StartTranscriptionJobResult; +import org.apache.nifi.annotation.behavior.WritesAttribute; +import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter; +import software.amazon.awssdk.services.transcribe.TranscribeClient; +import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder; +import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest; +import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse; @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"}) @CapabilityDescription("Trigger a AWS Transcribe job. 
It should be followed by GetAwsTranscribeStatus processor in order to monitor job status.") +@WritesAttributes({ + @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranscribeJobStatus") +}) @SeeAlso({GetAwsTranscribeJobStatus.class}) -public class StartAwsTranscribeJob extends AwsMachineLearningJobStarter { +public class StartAwsTranscribeJob extends AbstractAwsMachineLearningJobStarter< + StartTranscriptionJobRequest, StartTranscriptionJobRequest.Builder, StartTranscriptionJobResponse, TranscribeClient, TranscribeClientBuilder> { @Override - protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { - return (AmazonTranscribeClient) AmazonTranscribeClient.builder() - .withRegion(context.getProperty(REGION).getValue()) - .withClientConfiguration(config) - .withEndpointConfiguration(endpointConfiguration) - .withCredentials(credentialsProvider) - .build(); + protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) { + return TranscribeClient.builder(); } @Override - protected StartTranscriptionJobResult sendRequest(StartTranscriptionJobRequest request, ProcessContext context, FlowFile flowFile) { + protected StartTranscriptionJobResponse sendRequest(final StartTranscriptionJobRequest request, final ProcessContext context, final FlowFile flowFile) { return getClient(context).startTranscriptionJob(request); } @Override - protected Class getAwsRequestClass(ProcessContext context, FlowFile flowFile) { - return StartTranscriptionJobRequest.class; + protected Class getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) { + return StartTranscriptionJobRequest.serializableBuilderClass(); } @Override - protected String getAwsTaskId(ProcessContext context, 
StartTranscriptionJobResult startTranscriptionJobResult, FlowFile flowFile) { - return startTranscriptionJobResult.getTranscriptionJob().getTranscriptionJobName(); + protected String getAwsTaskId(final ProcessContext context, final StartTranscriptionJobResponse startTranscriptionJobResponse, final FlowFile flowFile) { + return startTranscriptionJobResponse.transcriptionJob().transcriptionJobName(); } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java index bc52a23efd..18c8dbaeea 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatus.java @@ -17,15 +17,6 @@ package org.apache.nifi.processors.aws.ml.translate; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.textract.model.ThrottlingException; -import com.amazonaws.services.translate.AmazonTranslateClient; -import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest; -import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult; -import com.amazonaws.services.translate.model.JobStatus; import org.apache.nifi.annotation.behavior.WritesAttribute; import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.documentation.CapabilityDescription; @@ -34,8 +25,15 @@ import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; 
import org.apache.nifi.processor.ProcessSession; +import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.exception.ProcessException; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor; +import software.amazon.awssdk.services.translate.TranslateClient; +import software.amazon.awssdk.services.translate.TranslateClientBuilder; +import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest; +import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse; +import software.amazon.awssdk.services.translate.model.JobStatus; +import software.amazon.awssdk.services.translate.model.LimitExceededException; @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"}) @CapabilityDescription("Retrieves the current status of an AWS Translate job.") @@ -43,54 +41,66 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; @WritesAttributes({ @WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.") }) -public class GetAwsTranslateJobStatus extends AwsMachineLearningJobStatusProcessor { +public class GetAwsTranslateJobStatus extends AbstractAwsMachineLearningJobStatusProcessor { @Override - protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { - return (AmazonTranslateClient) AmazonTranslateClient.builder() - .withRegion(context.getProperty(REGION).getValue()) - .withCredentials(credentialsProvider) - .withClientConfiguration(config) - .withEndpointConfiguration(endpointConfiguration) - .build(); + protected TranslateClientBuilder createClientBuilder(final ProcessContext context) { + return TranslateClient.builder(); } 
@Override - public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { + public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { FlowFile flowFile = session.get(); if (flowFile == null) { return; } - String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); try { - DescribeTextTranslationJobResult describeTextTranslationJobResult = getStatusString(context, awsTaskId); - JobStatus status = JobStatus.fromValue(describeTextTranslationJobResult.getTextTranslationJobProperties().getJobStatus()); + final DescribeTextTranslationJobResponse job = getJob(context, flowFile); + final JobStatus status = job.textTranslationJobProperties().jobStatus(); - if (status == JobStatus.IN_PROGRESS || status == JobStatus.SUBMITTED) { - writeToFlowFile(session, flowFile, describeTextTranslationJobResult); - session.penalize(flowFile); - session.transfer(flowFile, REL_RUNNING); - } else if (status == JobStatus.COMPLETED) { - session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, describeTextTranslationJobResult.getTextTranslationJobProperties().getOutputDataConfig().getS3Uri()); - writeToFlowFile(session, flowFile, describeTextTranslationJobResult); - session.transfer(flowFile, REL_SUCCESS); - } else if (status == JobStatus.FAILED || status == JobStatus.COMPLETED_WITH_ERROR) { - writeToFlowFile(session, flowFile, describeTextTranslationJobResult); - session.transfer(flowFile, REL_FAILURE); + flowFile = writeToFlowFile(session, flowFile, job); + final Relationship transferRelationship; + String failureReason = null; + switch (status) { + case IN_PROGRESS: + case SUBMITTED: + case STOP_REQUESTED: + flowFile = session.penalize(flowFile); + transferRelationship = REL_RUNNING; + break; + case COMPLETED: + flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.textTranslationJobProperties().outputDataConfig().s3Uri()); + transferRelationship = 
REL_SUCCESS; + break; + case FAILED: + case COMPLETED_WITH_ERROR: + failureReason = job.textTranslationJobProperties().message(); + transferRelationship = REL_FAILURE; + break; + case STOPPED: + failureReason = String.format("Job [%s] is stopped", job.textTranslationJobProperties().jobId()); + transferRelationship = REL_FAILURE; + break; + default: + failureReason = "Unknown Job Status"; + transferRelationship = REL_FAILURE; } - } catch (ThrottlingException e) { + if (failureReason != null) { + flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason); + } + session.transfer(flowFile, transferRelationship); + } catch (final LimitExceededException e) { getLogger().info("Request Rate Limit exceeded", e); session.transfer(flowFile, REL_THROTTLED); - } catch (Exception e) { - getLogger().warn("Failed to get Polly Job status", e); + } catch (final Exception e) { + getLogger().warn("Failed to get Translate Job status", e); session.transfer(flowFile, REL_FAILURE); } } - private DescribeTextTranslationJobResult getStatusString(ProcessContext context, String awsTaskId) { - DescribeTextTranslationJobRequest request = new DescribeTextTranslationJobRequest().withJobId(awsTaskId); - DescribeTextTranslationJobResult translationJobsResult = getClient(context).describeTextTranslationJob(request); - return translationJobsResult; + private DescribeTextTranslationJobResponse getJob(final ProcessContext context, final FlowFile flowFile) { + final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); + final DescribeTextTranslationJobRequest request = DescribeTextTranslationJobRequest.builder().jobId(taskId).build(); + return getClient(context).describeTextTranslationJob(request); } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java 
b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java index 3e7f079df9..db75233b09 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/main/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJob.java @@ -17,47 +17,44 @@ package org.apache.nifi.processors.aws.ml.translate; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.translate.AmazonTranslateClient; -import com.amazonaws.services.translate.model.StartTextTranslationJobRequest; -import com.amazonaws.services.translate.model.StartTextTranslationJobResult; +import org.apache.nifi.annotation.behavior.WritesAttribute; +import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.processor.ProcessContext; -import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; +import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter; +import software.amazon.awssdk.services.translate.TranslateClient; +import software.amazon.awssdk.services.translate.TranslateClientBuilder; +import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest; +import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse; @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"}) @CapabilityDescription("Trigger a AWS Translate job. 
It should be followed by GetAwsTranslateJobStatus processor in order to monitor job status.") +@WritesAttributes({ + @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranslateJobStatus") +}) @SeeAlso({GetAwsTranslateJobStatus.class}) -public class StartAwsTranslateJob extends AwsMachineLearningJobStarter { +public class StartAwsTranslateJob extends AbstractAwsMachineLearningJobStarter< + StartTextTranslationJobRequest, StartTextTranslationJobRequest.Builder, StartTextTranslationJobResponse, TranslateClient, TranslateClientBuilder> { @Override - protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { - return (AmazonTranslateClient) AmazonTranslateClient.builder() - .withRegion(context.getProperty(REGION).getValue()) - .withCredentials(credentialsProvider) - .withClientConfiguration(config) - .withEndpointConfiguration(endpointConfiguration) - .build(); + protected TranslateClientBuilder createClientBuilder(final ProcessContext context) { + return TranslateClient.builder(); } @Override - protected StartTextTranslationJobResult sendRequest(StartTextTranslationJobRequest request, ProcessContext context, FlowFile flowFile) { + protected StartTextTranslationJobResponse sendRequest(final StartTextTranslationJobRequest request, final ProcessContext context, final FlowFile flowFile) { return getClient(context).startTextTranslationJob(request); } @Override - protected Class getAwsRequestClass(ProcessContext context, FlowFile flowFile) { - return StartTextTranslationJobRequest.class; + protected Class getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) { + return StartTextTranslationJobRequest.serializableBuilderClass(); } - protected String getAwsTaskId(ProcessContext context, 
StartTextTranslationJobResult startTextTranslationJobResult, FlowFile flowFile) { - return startTextTranslationJobResult.getJobId(); + protected String getAwsTaskId(final ProcessContext context, final StartTextTranslationJobResponse startTextTranslationJobResponse, final FlowFile flowFile) { + return startTextTranslationJobResponse.jobId(); } } diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java index d9b7b51ffc..73af8f04c0 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/GetAwsPollyStatusTest.java @@ -17,18 +17,10 @@ package org.apache.nifi.processors.aws.ml.polly; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.polly.AmazonPollyClient; -import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest; -import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult; -import com.amazonaws.services.polly.model.SynthesisTask; -import com.amazonaws.services.polly.model.TaskStatus; import org.apache.nifi.processor.ProcessContext; -import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; +import org.apache.nifi.processors.aws.testutil.AuthUtils; import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.junit.jupiter.api.BeforeEach; @@ -38,16 +30,20 @@ import 
org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.services.polly.PollyClient; +import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest; +import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse; +import software.amazon.awssdk.services.polly.model.SynthesisTask; +import software.amazon.awssdk.services.polly.model.TaskStatus; import java.util.Collections; -import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_ORIGINAL; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; @@ 
-57,73 +53,84 @@ public class GetAwsPollyStatusTest { private static final String PLACEHOLDER_CONTENT = "content"; private TestRunner runner; @Mock - private AmazonPollyClient mockPollyClient; - @Mock - private AWSCredentialsProviderService mockAwsCredentialsProvider; + private PollyClient mockPollyClient; + + private GetAwsPollyJobStatus processor; + @Captor private ArgumentCaptor requestCaptor; + private TestRunner createRunner(final GetAwsPollyJobStatus processor) { + final TestRunner runner = TestRunners.newTestRunner(processor); + AuthUtils.enableAccessKey(runner, "abcd", "defg"); + return runner; + } + @BeforeEach public void setUp() throws InitializationException { - when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider"); - - final GetAwsPollyJobStatus mockGetAwsPollyStatus = new GetAwsPollyJobStatus() { + processor = new GetAwsPollyJobStatus() { @Override - protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { + public PollyClient getClient(final ProcessContext context) { return mockPollyClient; } }; - runner = TestRunners.newTestRunner(mockGetAwsPollyStatus); - runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider); - runner.enableControllerService(mockAwsCredentialsProvider); - runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider"); + runner = createRunner(processor); } @Test public void testPollyTaskInProgress() { - GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult(); - SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID) - .withTaskStatus(TaskStatus.InProgress); - taskResult.setSynthesisTask(task); - when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult); + GetSpeechSynthesisTaskResponse response = 
GetSpeechSynthesisTaskResponse.builder() + .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).taskStatus(TaskStatus.IN_PROGRESS).build()) + .build(); + when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response); runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); runner.run(); runner.assertAllFlowFilesTransferred(REL_RUNNING); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId()); + assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId()); } @Test public void testPollyTaskCompleted() { - GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult(); - SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID) - .withTaskStatus(TaskStatus.Completed) - .withOutputUri("outputLocationPath"); - taskResult.setSynthesisTask(task); - when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult); + final String uri = "https://s3.us-west2.amazonaws.com/bucket/object"; + final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder() + .synthesisTask(SynthesisTask.builder() + .taskId(TEST_TASK_ID) + .taskStatus(TaskStatus.COMPLETED) + .outputUri(uri).build()) + .build(); + when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response); runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); runner.run(); runner.assertTransferCount(REL_SUCCESS, 1); runner.assertTransferCount(REL_ORIGINAL, 1); runner.assertAllFlowFilesContainAttribute(REL_SUCCESS, AWS_TASK_OUTPUT_LOCATION); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId()); - } + assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId()); + final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next(); + assertEquals(uri, flowFile.getAttribute(GetAwsPollyJobStatus.AWS_TASK_OUTPUT_LOCATION)); + assertEquals("bucket", 
flowFile.getAttribute("PollyS3OutputBucket")); + assertEquals("object", flowFile.getAttribute("filename")); + } @Test public void testPollyTaskFailed() { - GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult(); - SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID) - .withTaskStatus(TaskStatus.Failed) - .withTaskStatusReason("reasonOfFailure"); - taskResult.setSynthesisTask(task); - when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult); + final String failureReason = "reasonOfFailure"; + final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder() + .synthesisTask(SynthesisTask.builder() + .taskId(TEST_TASK_ID) + .taskStatus(TaskStatus.FAILED) + .taskStatusReason(failureReason).build()) + .build(); + when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response); runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); runner.run(); runner.assertAllFlowFilesTransferred(REL_FAILURE); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId()); + assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId()); + + final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next(); + assertEquals(failureReason, flowFile.getAttribute(GetAwsPollyJobStatus.FAILURE_REASON_ATTRIBUTE)); } } \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJobTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJobTest.java new file mode 100644 index 0000000000..4e68e2d7a1 --- /dev/null +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/polly/StartAwsPollyJobTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.processors.aws.ml.polly; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processors.aws.testutil.AuthUtils; +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.services.polly.PollyClient; +import software.amazon.awssdk.services.polly.model.Engine; +import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest; +import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse; +import software.amazon.awssdk.services.polly.model.SynthesisTask; + +import 
java.util.HashMap; +import java.util.Map; + +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class StartAwsPollyJobTest { + private static final String TEST_TASK_ID = "testTaskId"; + private TestRunner runner; + @Mock + private PollyClient mockPollyClient; + + private StartAwsPollyJob processor; + + private ObjectMapper objectMapper = JsonMapper.builder() + .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true) + .build(); + @Captor + private ArgumentCaptor requestCaptor; + + private TestRunner createRunner(final StartAwsPollyJob processor) { + final TestRunner runner = TestRunners.newTestRunner(processor); + AuthUtils.enableAccessKey(runner, "abcd", "defg"); + return runner; + } + + @BeforeEach + public void setUp() throws InitializationException { + processor = new StartAwsPollyJob() { + @Override + public PollyClient getClient(ProcessContext context) { + return mockPollyClient; + } + }; + runner = createRunner(processor); + } + + @Test + public void testSuccessfulFlowfileContent() throws JsonProcessingException { + final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder() + .engine(Engine.NEURAL) + .text("Text") + .build(); + final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder() + .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build()) + .build(); + when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.enqueue(requestJson); + runner.run(); + + 
runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData); + + assertEquals("Text", requestCaptor.getValue().text()); + assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId()); + } + + @Test + public void testSuccessfulAttribute() throws JsonProcessingException { + final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder() + .engine(Engine.NEURAL) + .text("Text") + .build(); + final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder() + .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build()) + .build(); + when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsPollyJob.JSON_PAYLOAD, "${json.payload}"); + final Map attributes = new HashMap<>(); + attributes.put("json.payload", requestJson); + runner.enqueue("", attributes); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData); + + assertEquals("Text", requestCaptor.getValue().text()); + assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId()); + } + + @Test + public void testInvalidJson() { + final String requestJson = "invalid"; + runner.enqueue(requestJson); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + @Test + public void testServiceFailure() throws JsonProcessingException { + final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder() + .engine(Engine.NEURAL) + 
.text("Text") + .build(); + when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build()); + + final String requestJson = serialize(request); + runner.enqueue(requestJson); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + private StartSpeechSynthesisTaskResponse deserialize(final String responseData) throws JsonProcessingException { + return objectMapper.readValue(responseData, StartSpeechSynthesisTaskResponse.serializableBuilderClass()).build(); + } + + private String serialize(final StartSpeechSynthesisTaskRequest request) throws JsonProcessingException { + return objectMapper.writeValueAsString(request.toBuilder()); + } +} \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java index 82fb01cdcf..9cead6ef96 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/GetAwsTextractJobStatusTest.java @@ -17,17 +17,10 @@ package org.apache.nifi.processors.aws.ml.textract; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.textract.AmazonTextractClient; -import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest; -import com.amazonaws.services.textract.model.GetDocumentAnalysisResult; -import com.amazonaws.services.textract.model.JobStatus; import com.google.common.collect.ImmutableMap; import 
org.apache.nifi.processor.ProcessContext; -import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; +import org.apache.nifi.processor.Relationship; +import org.apache.nifi.processors.aws.testutil.AuthUtils; import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; @@ -38,13 +31,20 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.services.textract.TextractClient; +import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest; +import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisResponse; +import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest; +import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionResponse; +import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest; +import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisResponse; +import software.amazon.awssdk.services.textract.model.JobStatus; -import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; -import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static 
org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_THROTTLED; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; @@ -53,64 +53,144 @@ public class GetAwsTextractJobStatusTest { private static final String TEST_TASK_ID = "testTaskId"; private TestRunner runner; @Mock - private AmazonTextractClient mockTextractClient; - @Mock - private AWSCredentialsProviderService mockAwsCredentialsProvider; + private TextractClient mockTextractClient; + + private GetAwsTextractJobStatus processor; + @Captor - private ArgumentCaptor requestCaptor; + private ArgumentCaptor documentAnalysisCaptor; + @Captor + private ArgumentCaptor expenseAnalysisRequestCaptor; + @Captor + private ArgumentCaptor documentTextDetectionCaptor; + + private TestRunner createRunner(final GetAwsTextractJobStatus processor) { + final TestRunner runner = TestRunners.newTestRunner(processor); + AuthUtils.enableAccessKey(runner, "abcd", "defg"); + return runner; + } @BeforeEach public void setUp() throws InitializationException { - when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider"); - final GetAwsTextractJobStatus awsTextractJobStatusGetter = new GetAwsTextractJobStatus() { + processor = new GetAwsTextractJobStatus() { @Override - protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { + public TextractClient getClient(final ProcessContext context) { return mockTextractClient; } }; - runner = 
TestRunners.newTestRunner(awsTextractJobStatusGetter); - runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider); - runner.enableControllerService(mockAwsCredentialsProvider); - runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider"); + runner = createRunner(processor); } @Test public void testTextractDocAnalysisTaskInProgress() { - GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult() - .withJobStatus(JobStatus.IN_PROGRESS); - when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID, - TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name())); - runner.run(); - - runner.assertAllFlowFilesTransferred(REL_RUNNING); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); + testTextractDocAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING); } @Test public void testTextractDocAnalysisTaskComplete() { - GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult() - .withJobStatus(JobStatus.SUCCEEDED); - when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID, - TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name())); - runner.run(); - - runner.assertAllFlowFilesTransferred(REL_SUCCESS); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); + testTextractDocAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS); } @Test public void testTextractDocAnalysisTaskFailed() { - GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult() - .withJobStatus(JobStatus.FAILED); - when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID, - TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.type)); + testTextractDocAnalysis(JobStatus.FAILED, 
REL_FAILURE); + } + + @Test + public void testTextractDocAnalysisTaskPartialSuccess() { + testTextractDocAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED); + } + + @Test + public void testTextractDocAnalysisTaskUnkownStatus() { + testTextractDocAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE); + } + + private void testTextractDocAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) { + final GetDocumentAnalysisResponse response = GetDocumentAnalysisResponse.builder() + .jobStatus(jobStatus).build(); + when(mockTextractClient.getDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response); + runner.enqueue("content", ImmutableMap.of( + TASK_ID.getName(), TEST_TASK_ID, + StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE, TextractType.DOCUMENT_ANALYSIS.getType())); runner.run(); - runner.assertAllFlowFilesTransferred(REL_FAILURE); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); + runner.assertAllFlowFilesTransferred(expectedRelationship); + assertEquals(TEST_TASK_ID, documentAnalysisCaptor.getValue().jobId()); + } + + @Test + public void testTextractExpenseAnalysisTaskInProgress() { + testTextractExpenseAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING); + } + + @Test + public void testTextractExpenseAnalysisTaskComplete() { + testTextractExpenseAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS); + } + + @Test + public void testTextractExpenseAnalysisTaskFailed() { + testTextractExpenseAnalysis(JobStatus.FAILED, REL_FAILURE); + } + + @Test + public void testTextractExpenseAnalysisTaskPartialSuccess() { + testTextractExpenseAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED); + } + + @Test + public void testTextractExpenseAnalysisTaskUnkownStatus() { + testTextractExpenseAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE); + } + + private void testTextractExpenseAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) { + runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, 
TextractType.EXPENSE_ANALYSIS.getType()); + final GetExpenseAnalysisResponse response = GetExpenseAnalysisResponse.builder() + .jobStatus(jobStatus).build(); + when(mockTextractClient.getExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response); + runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID)); + runner.run(); + + runner.assertAllFlowFilesTransferred(expectedRelationship); + assertEquals(TEST_TASK_ID, expenseAnalysisRequestCaptor.getValue().jobId()); + } + + @Test + public void testTextractDocumentTextDetectionTaskInProgress() { + testTextractDocumentTextDetection(JobStatus.IN_PROGRESS, REL_RUNNING); + } + + @Test + public void testTextractDocumentTextDetectionTaskComplete() { + testTextractDocumentTextDetection(JobStatus.SUCCEEDED, REL_SUCCESS); + } + + @Test + public void testTextractDocumentTextDetectionTaskFailed() { + testTextractDocumentTextDetection(JobStatus.FAILED, REL_FAILURE); + } + + @Test + public void testTextractDocumentTextDetectionTaskPartialSuccess() { + testTextractDocumentTextDetection(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED); + } + + @Test + public void testTextractDocumentTextDetectionTaskUnkownStatus() { + testTextractDocumentTextDetection(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE); + } + + private void testTextractDocumentTextDetection(final JobStatus jobStatus, final Relationship expectedRelationship) { + runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.DOCUMENT_TEXT_DETECTION.getType()); + + final GetDocumentTextDetectionResponse response = GetDocumentTextDetectionResponse.builder() + .jobStatus(jobStatus).build(); + when(mockTextractClient.getDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response); + runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID)); + runner.run(); + + runner.assertAllFlowFilesTransferred(expectedRelationship); + assertEquals(TEST_TASK_ID, documentTextDetectionCaptor.getValue().jobId()); } } \ 
No newline at end of file diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJobStatusTest.java new file mode 100644 index 0000000000..16a4b846df --- /dev/null +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/textract/StartAwsTextractJobStatusTest.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.nifi.processors.aws.ml.textract; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processors.aws.testutil.AuthUtils; +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.services.textract.TextractClient; +import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest; +import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse; +import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest; +import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse; +import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest; +import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; +import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS; +import static 
org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_TEXT_DETECTION; +import static org.apache.nifi.processors.aws.ml.textract.TextractType.EXPENSE_ANALYSIS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class StartAwsTextractJobStatusTest { + private static final String TEST_TASK_ID = "testTaskId"; + private TestRunner runner; + @Mock + private TextractClient mockTextractClient; + + private StartAwsTextractJob processor; + + private ObjectMapper objectMapper = JsonMapper.builder() + .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true) + .build(); + @Captor + private ArgumentCaptor documentAnalysisCaptor; + @Captor + private ArgumentCaptor expenseAnalysisRequestCaptor; + @Captor + private ArgumentCaptor documentTextDetectionCaptor; + + private TestRunner createRunner(final StartAwsTextractJob processor) { + final TestRunner runner = TestRunners.newTestRunner(processor); + AuthUtils.enableAccessKey(runner, "abcd", "defg"); + return runner; + } + + @BeforeEach + public void setUp() throws InitializationException { + processor = new StartAwsTextractJob() { + @Override + public TextractClient getClient(ProcessContext context) { + return mockTextractClient; + } + }; + runner = createRunner(processor); + } + + @Test + public void testSuccessfulDocumentAnalysisFlowfileContent() throws JsonProcessingException { + final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder() + .jobTag("Tag") + .build(); + final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder() + .jobId(TEST_TASK_ID) + .build(); + when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType()); + runner.enqueue(requestJson); + runner.run(); + + 
runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData); + + assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag()); + assertEquals(TEST_TASK_ID, parsedResponse.jobId()); + } + + @Test + public void testSuccessfulDocumentAnalysisAttribute() throws JsonProcessingException { + final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder() + .jobTag("Tag") + .build(); + final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder() + .jobId(TEST_TASK_ID) + .build(); + when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType()); + runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}"); + final Map attributes = new HashMap<>(); + attributes.put("json.payload", requestJson); + runner.enqueue("", attributes); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData); + + assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag()); + assertEquals(TEST_TASK_ID, parsedResponse.jobId()); + } + + @Test + public void testInvalidDocumentAnalysisJson() { + final String requestJson = "invalid"; + runner.enqueue(requestJson); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + @Test + public void testDocumentAnalysisServiceFailure() throws JsonProcessingException { + final StartDocumentAnalysisRequest request = 
StartDocumentAnalysisRequest.builder() + .jobTag("Tag") + .build(); + when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build()); + + final String requestJson = serialize(request); + runner.enqueue(requestJson); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + @Test + public void testSuccessfulExpenseAnalysisFlowfileContent() throws JsonProcessingException { + final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder() + .jobTag("Tag") + .build(); + final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder() + .jobId(TEST_TASK_ID) + .build(); + when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType()); + runner.enqueue(requestJson); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData); + + assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag()); + assertEquals(TEST_TASK_ID, parsedResponse.jobId()); + } + + @Test + public void testSuccessfulExpenseAnalysisAttribute() throws JsonProcessingException { + final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder() + .jobTag("Tag") + .build(); + final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder() + .jobId(TEST_TASK_ID) + .build(); + when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, 
EXPENSE_ANALYSIS.getType()); + runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}"); + final Map attributes = new HashMap<>(); + attributes.put("json.payload", requestJson); + runner.enqueue("", attributes); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData); + + assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag()); + assertEquals(TEST_TASK_ID, parsedResponse.jobId()); + } + + @Test + public void testInvalidExpenseAnalysisJson() { + final String requestJson = "invalid"; + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType()); + runner.enqueue(requestJson); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + @Test + public void testExpenseAnalysisServiceFailure() throws JsonProcessingException { + final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder() + .jobTag("Tag") + .build(); + when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build()); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType()); + runner.enqueue(requestJson); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + @Test + public void testSuccessfulDocumentTextDetectionFlowfileContent() throws JsonProcessingException { + final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder() + .jobTag("Tag") + .build(); + final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder() + .jobId(TEST_TASK_ID) + .build(); + 
when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType()); + runner.enqueue(requestJson); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData); + + assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag()); + assertEquals(TEST_TASK_ID, parsedResponse.jobId()); + } + + @Test + public void testSuccessfulDocumentTextDetectionAttribute() throws JsonProcessingException { + final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder() + .jobTag("Tag") + .build(); + final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder() + .jobId(TEST_TASK_ID) + .build(); + when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType()); + runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}"); + final Map attributes = new HashMap<>(); + attributes.put("json.payload", requestJson); + runner.enqueue("", attributes); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData); + + assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag()); + assertEquals(TEST_TASK_ID, 
parsedResponse.jobId()); + } + + @Test + public void testInvalidDocumentTextDetectionJson() { + final String requestJson = "invalid"; + runner.enqueue(requestJson); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType()); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + @Test + public void testDocumentTextDetectionServiceFailure() throws JsonProcessingException { + final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder() + .jobTag("Tag") + .build(); + when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build()); + + final String requestJson = serialize(request); + runner.enqueue(requestJson); + runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType()); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + private StartDocumentTextDetectionResponse deserializeDTDRequest(final String responseData) throws JsonProcessingException { + return objectMapper.readValue(responseData, StartDocumentTextDetectionResponse.serializableBuilderClass()).build(); + } + + private StartDocumentAnalysisResponse deserializeDARequest(final String responseData) throws JsonProcessingException { + return objectMapper.readValue(responseData, StartDocumentAnalysisResponse.serializableBuilderClass()).build(); + } + + private StartExpenseAnalysisResponse deserializeEARequest(final String responseData) throws JsonProcessingException { + return objectMapper.readValue(responseData, StartExpenseAnalysisResponse.serializableBuilderClass()).build(); + } + + private String serialize(final StartDocumentAnalysisRequest request) throws JsonProcessingException { + return objectMapper.writeValueAsString(request.toBuilder()); + } + + private String serialize(final StartExpenseAnalysisRequest request) throws JsonProcessingException { + return 
objectMapper.writeValueAsString(request.toBuilder()); + } + + private String serialize(final StartDocumentTextDetectionRequest request) throws JsonProcessingException { + return objectMapper.writeValueAsString(request.toBuilder()); + } +} \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java index 68a253dc7a..9e139d2c31 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/GetAwsTranscribeJobStatusTest.java @@ -17,19 +17,11 @@ package org.apache.nifi.processors.aws.ml.transcribe; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.transcribe.AmazonTranscribeClient; -import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest; -import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult; -import com.amazonaws.services.transcribe.model.Transcript; -import com.amazonaws.services.transcribe.model.TranscriptionJob; -import com.amazonaws.services.transcribe.model.TranscriptionJobStatus; import org.apache.nifi.processor.ProcessContext; -import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; +import org.apache.nifi.processor.Relationship; +import org.apache.nifi.processors.aws.testutil.AuthUtils; import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.TestRunner; import 
org.apache.nifi.util.TestRunners; import org.junit.jupiter.api.BeforeEach; @@ -39,93 +31,118 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.services.transcribe.TranscribeClient; +import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest; +import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse; +import software.amazon.awssdk.services.transcribe.model.Transcript; +import software.amazon.awssdk.services.transcribe.model.TranscriptionJob; +import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus; import java.util.Collections; -import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; +import static 
org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class GetAwsTranscribeJobStatusTest { - private static final String TEST_TASK_ID = "testTaskId"; - private static final String AWS_CREDENTIAL_PROVIDER_NAME = "awsCredetialProvider"; + private static final String TEST_TASK_ID = "testJobId"; private static final String OUTPUT_LOCATION_PATH = "outputLocationPath"; private static final String REASON_OF_FAILURE = "reasonOfFailure"; private static final String CONTENT_STRING = "content"; private TestRunner runner; @Mock - private AmazonTranscribeClient mockTranscribeClient; - @Mock - private AWSCredentialsProviderService mockAwsCredentialsProvider; + private TranscribeClient mockTranscribeClient; + + private GetAwsTranscribeJobStatus processor; + @Captor private ArgumentCaptor requestCaptor; + private TestRunner createRunner(final GetAwsTranscribeJobStatus processor) { + final TestRunner runner = TestRunners.newTestRunner(processor); + AuthUtils.enableAccessKey(runner, "abcd", "defg"); + return runner; + } + @BeforeEach public void setUp() throws InitializationException { - when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIAL_PROVIDER_NAME); - final GetAwsTranscribeJobStatus mockPollyFetcher = new GetAwsTranscribeJobStatus() { + processor = new GetAwsTranscribeJobStatus() { @Override - protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { + public TranscribeClient getClient(final ProcessContext context) { return mockTranscribeClient; } }; - runner = TestRunners.newTestRunner(mockPollyFetcher); - runner.addControllerService(AWS_CREDENTIAL_PROVIDER_NAME, 
mockAwsCredentialsProvider); - runner.enableControllerService(mockAwsCredentialsProvider); - runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIAL_PROVIDER_NAME); + runner = createRunner(processor); } @Test - public void testTranscribeTaskInProgress() { - TranscriptionJob task = new TranscriptionJob() - .withTranscriptionJobName(TEST_TASK_ID) - .withTranscriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS); - GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task); - when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); - runner.run(); - - runner.assertAllFlowFilesTransferred(REL_RUNNING); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName()); + public void testTranscribeJobInProgress() { + final TranscriptionJob job = TranscriptionJob.builder() + .transcriptionJobName(TEST_TASK_ID) + .transcriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS) + .build(); + testTranscribeJob(job, REL_RUNNING); } @Test - public void testTranscribeTaskCompleted() { - TranscriptionJob task = new TranscriptionJob() - .withTranscriptionJobName(TEST_TASK_ID) - .withTranscript(new Transcript().withTranscriptFileUri(OUTPUT_LOCATION_PATH)) - .withTranscriptionJobStatus(TranscriptionJobStatus.COMPLETED); - GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task); - when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); - runner.run(); - - runner.assertAllFlowFilesTransferred(REL_SUCCESS); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName()); + public void testTranscribeJobQueued() { + final TranscriptionJob job = TranscriptionJob.builder() + .transcriptionJobName(TEST_TASK_ID) 
+ .transcriptionJobStatus(TranscriptionJobStatus.QUEUED) + .build(); + testTranscribeJob(job, REL_RUNNING); } + @Test + public void testTranscribeJobCompleted() { + final TranscriptionJob job = TranscriptionJob.builder() + .transcriptionJobName(TEST_TASK_ID) + .transcript(Transcript.builder().transcriptFileUri(OUTPUT_LOCATION_PATH).build()) + .transcriptionJobStatus(TranscriptionJobStatus.COMPLETED) + .build(); + testTranscribeJob(job, REL_SUCCESS); + runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION); + final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next(); + assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION)); + } @Test - public void testPollyTaskFailed() { - TranscriptionJob task = new TranscriptionJob() - .withTranscriptionJobName(TEST_TASK_ID) - .withFailureReason(REASON_OF_FAILURE) - .withTranscriptionJobStatus(TranscriptionJobStatus.FAILED); - GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task); - when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); - runner.run(); - - runner.assertAllFlowFilesTransferred(REL_FAILURE); + public void testTranscribeJobFailed() { + final TranscriptionJob job = TranscriptionJob.builder() + .transcriptionJobName(TEST_TASK_ID) + .failureReason(REASON_OF_FAILURE) + .transcriptionJobStatus(TranscriptionJobStatus.FAILED) + .build(); + testTranscribeJob(job, REL_FAILURE); runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName()); + final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next(); + assertEquals(REASON_OF_FAILURE, flowFile.getAttribute(FAILURE_REASON_ATTRIBUTE)); + } + @Test + public void testTranscribeJobUnrecognized() { + final 
TranscriptionJob job = TranscriptionJob.builder() + .transcriptionJobName(TEST_TASK_ID) + .failureReason(REASON_OF_FAILURE) + .transcriptionJobStatus(TranscriptionJobStatus.UNKNOWN_TO_SDK_VERSION) + .build(); + testTranscribeJob(job, REL_FAILURE); + runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE); + } + + private void testTranscribeJob(final TranscriptionJob job, final Relationship expectedRelationship) { + final GetTranscriptionJobResponse response = GetTranscriptionJobResponse.builder().transcriptionJob(job).build(); + when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(response); + runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); + runner.run(); + + runner.assertAllFlowFilesTransferred(expectedRelationship); + assertEquals(TEST_TASK_ID, requestCaptor.getValue().transcriptionJobName()); } } \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJobTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJobTest.java new file mode 100644 index 0000000000..1da8247162 --- /dev/null +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/transcribe/StartAwsTranscribeJobTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.processors.aws.ml.transcribe; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processors.aws.testutil.AuthUtils; +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.services.transcribe.TranscribeClient; +import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest; +import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse; +import software.amazon.awssdk.services.transcribe.model.TranscriptionJob; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; 
+import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class StartAwsTranscribeJobTest { + private static final String TEST_TASK_ID = "testTaskId"; + private TestRunner runner; + @Mock + private TranscribeClient mockTranscribeClient; + + private StartAwsTranscribeJob processor; + + private ObjectMapper objectMapper = JsonMapper.builder() + .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true) + .build(); + @Captor + private ArgumentCaptor requestCaptor; + + private TestRunner createRunner(final StartAwsTranscribeJob processor) { + final TestRunner runner = TestRunners.newTestRunner(processor); + AuthUtils.enableAccessKey(runner, "abcd", "defg"); + return runner; + } + + @BeforeEach + public void setUp() throws InitializationException { + processor = new StartAwsTranscribeJob() { + @Override + public TranscribeClient getClient(ProcessContext context) { + return mockTranscribeClient; + } + }; + runner = createRunner(processor); + } + + @Test + public void testSuccessfulFlowfileContent() throws JsonProcessingException { + final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder() + .transcriptionJobName("Job") + .build(); + final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder() + .transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build()) + .build(); + when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.enqueue(requestJson); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartTranscriptionJobResponse parsedResponse = deserialize(responseData); + + assertEquals("Job", 
requestCaptor.getValue().transcriptionJobName()); + assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName()); + } + + @Test + public void testSuccessfulAttribute() throws JsonProcessingException { + final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder() + .transcriptionJobName("Job") + .build(); + final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder() + .transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build()) + .build(); + when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response); + + final String requestJson = serialize(request); + runner.setProperty(StartAwsTranscribeJob.JSON_PAYLOAD, "${json.payload}"); + final Map attributes = new HashMap<>(); + attributes.put("json.payload", requestJson); + runner.enqueue("", attributes); + runner.run(); + + runner.assertTransferCount(REL_SUCCESS, 1); + runner.assertTransferCount(REL_ORIGINAL, 1); + final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent(); + final StartTranscriptionJobResponse parsedResponse = deserialize(responseData); + + assertEquals("Job", requestCaptor.getValue().transcriptionJobName()); + assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName()); + } + + @Test + public void testInvalidJson() { + final String requestJson = "invalid"; + runner.enqueue(requestJson); + runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + @Test + public void testServiceFailure() throws JsonProcessingException { + final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder() + .transcriptionJobName("Job") + .build(); + when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build()); + + final String requestJson = serialize(request); + runner.enqueue(requestJson); + 
runner.run(); + + runner.assertAllFlowFilesTransferred(REL_FAILURE, 1); + } + + private StartTranscriptionJobResponse deserialize(final String responseData) throws JsonProcessingException { + return objectMapper.readValue(responseData, StartTranscriptionJobResponse.serializableBuilderClass()).build(); + } + + private String serialize(final StartTranscriptionJobRequest request) throws JsonProcessingException { + return objectMapper.writeValueAsString(request.toBuilder()); + } +} \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java index 7c868c70ce..dd3b701348 100644 --- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java +++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/GetAwsTranslateJobStatusTest.java @@ -17,19 +17,11 @@ package org.apache.nifi.processors.aws.ml.translate; -import com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.regions.Region; -import com.amazonaws.services.translate.AmazonTranslateClient; -import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest; -import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult; -import com.amazonaws.services.translate.model.JobStatus; -import com.amazonaws.services.translate.model.OutputDataConfig; -import com.amazonaws.services.translate.model.TextTranslationJobProperties; import org.apache.nifi.processor.ProcessContext; -import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; +import 
org.apache.nifi.processor.Relationship; +import org.apache.nifi.processors.aws.testutil.AuthUtils; import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.MockFlowFile; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.junit.jupiter.api.BeforeEach; @@ -39,91 +31,132 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.services.translate.TranslateClient; +import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest; +import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse; +import software.amazon.awssdk.services.translate.model.JobStatus; +import software.amazon.awssdk.services.translate.model.OutputDataConfig; +import software.amazon.awssdk.services.translate.model.TextTranslationJobProperties; +import java.time.Instant; import java.util.Collections; -import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; -import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING; +import static 
org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) public class GetAwsTranslateJobStatusTest { - private static final String TEST_TASK_ID = "testTaskId"; + private static final String TEST_TASK_ID = "testJobId"; + private static final String OUTPUT_LOCATION_PATH = "outputLocationPath"; + private static final String REASON_OF_FAILURE = "reasonOfFailure"; private static final String CONTENT_STRING = "content"; - private static final String AWS_CREDENTIALS_PROVIDER_NAME = "awsCredetialProvider"; - private static final String OUTPUT_LOCATION_PATH = "outputLocation"; private TestRunner runner; @Mock - private AmazonTranslateClient mockTranslateClient; - @Mock - private AWSCredentialsProviderService mockAwsCredentialsProvider; + private TranslateClient mockTranslateClient; + + private GetAwsTranslateJobStatus processor; + @Captor private ArgumentCaptor requestCaptor; + private TestRunner createRunner(final GetAwsTranslateJobStatus processor) { + final TestRunner runner = TestRunners.newTestRunner(processor); + AuthUtils.enableAccessKey(runner, "abcd", "defg"); + return runner; + } @BeforeEach public void setUp() throws InitializationException { - when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIALS_PROVIDER_NAME); - final GetAwsTranslateJobStatus mockPollyFetcher = new GetAwsTranslateJobStatus() { + processor = new GetAwsTranslateJobStatus() { @Override - protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, - final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { + public TranslateClient getClient(final ProcessContext context) { 
return mockTranslateClient; } }; - runner = TestRunners.newTestRunner(mockPollyFetcher); - runner.addControllerService(AWS_CREDENTIALS_PROVIDER_NAME, mockAwsCredentialsProvider); - runner.enableControllerService(mockAwsCredentialsProvider); - runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIALS_PROVIDER_NAME); + runner = createRunner(processor); } @Test - public void testTranscribeTaskInProgress() { - TextTranslationJobProperties task = new TextTranslationJobProperties() - .withJobId(TEST_TASK_ID) - .withJobStatus(JobStatus.IN_PROGRESS); - DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task); - when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); - runner.run(); - - runner.assertAllFlowFilesTransferred(REL_RUNNING); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); + public void testTranslateJobInProgress() { + final TextTranslationJobProperties job = TextTranslationJobProperties.builder() + .jobId(TEST_TASK_ID) + .jobStatus(JobStatus.IN_PROGRESS) + .build(); + testTranslateJob(job, REL_RUNNING); } @Test - public void testTranscribeTaskCompleted() { - TextTranslationJobProperties task = new TextTranslationJobProperties() - .withJobId(TEST_TASK_ID) - .withOutputDataConfig(new OutputDataConfig().withS3Uri(OUTPUT_LOCATION_PATH)) - .withJobStatus(JobStatus.COMPLETED); - DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task); - when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult); - runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); - runner.run(); + public void testTranslateSubmitted() { + final TextTranslationJobProperties job = TextTranslationJobProperties.builder() + 
.jobId(TEST_TASK_ID) + .jobStatus(JobStatus.SUBMITTED) + .build(); + testTranslateJob(job, REL_RUNNING); + } - runner.assertAllFlowFilesTransferred(REL_SUCCESS); + @Test + public void testTranslateStopRequested() { + final TextTranslationJobProperties job = TextTranslationJobProperties.builder() + .jobId(TEST_TASK_ID) + .jobStatus(JobStatus.STOP_REQUESTED) + .build(); + testTranslateJob(job, REL_RUNNING); + } + + @Test + public void testTranslateJobCompleted() { + final TextTranslationJobProperties job = TextTranslationJobProperties.builder() + .jobId(TEST_TASK_ID) + .outputDataConfig(OutputDataConfig.builder().s3Uri(OUTPUT_LOCATION_PATH).build()) + .submittedTime(Instant.now()) + .jobStatus(JobStatus.COMPLETED) + .build(); + testTranslateJob(job, REL_SUCCESS); runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION); - assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); + final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next(); + assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION)); } @Test - public void testTranscribeTaskFailed() { - TextTranslationJobProperties task = new TextTranslationJobProperties() - .withJobId(TEST_TASK_ID) - .withJobStatus(JobStatus.FAILED); - DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task); - when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult); + public void testTranslateJobFailed() { + final TextTranslationJobProperties job = TextTranslationJobProperties.builder() + .jobId(TEST_TASK_ID) + .jobStatus(JobStatus.FAILED) + .build(); + testTranslateJob(job, REL_FAILURE); + } + + @Test + public void testTranslateJobStopped() { + final TextTranslationJobProperties job = TextTranslationJobProperties.builder() + .jobId(TEST_TASK_ID) + .jobStatus(JobStatus.STOPPED) + .build(); + testTranslateJob(job, REL_FAILURE); + } + 
+    @Test
+    public void testTranslateJobUnrecognized() {
+        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
+                .jobId(TEST_TASK_ID)
+                .jobStatus(JobStatus.UNKNOWN_TO_SDK_VERSION)
+                .build();
+        testTranslateJob(job, REL_FAILURE);
+    }
+
+    private void testTranslateJob(final TextTranslationJobProperties job, final Relationship expectedRelationship) {
+        final DescribeTextTranslationJobResponse response = DescribeTextTranslationJobResponse.builder().textTranslationJobProperties(job).build();
+        when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(response);
         runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
         runner.run();
 
-        runner.assertAllFlowFilesTransferred(REL_FAILURE);
-        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
+        runner.assertAllFlowFilesTransferred(expectedRelationship);
+        assertEquals(TEST_TASK_ID, requestCaptor.getValue().jobId());
     }
 }
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJobTest.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJobTest.java
new file mode 100644
index 0000000000..d8e1568c1d
--- /dev/null
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/ml/translate/StartAwsTranslateJobTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.processors.aws.ml.translate; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.json.JsonMapper; +import org.apache.nifi.processor.ProcessContext; +import org.apache.nifi.processors.aws.testutil.AuthUtils; +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.services.translate.TranslateClient; +import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest; +import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL; +import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS; +import static 
org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.when;
+
+@ExtendWith(MockitoExtension.class)
+public class StartAwsTranslateJobTest {
+    private static final String TEST_TASK_ID = "testTaskId";
+    private TestRunner runner;
+    @Mock
+    private TranslateClient mockTranslateClient;
+
+    private StartAwsTranslateJob processor;
+
+    private final ObjectMapper objectMapper = JsonMapper.builder()
+            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
+            .build();
+    @Captor
+    private ArgumentCaptor<StartTextTranslationJobRequest> requestCaptor;
+
+    private TestRunner createRunner(final StartAwsTranslateJob processor) {
+        final TestRunner runner = TestRunners.newTestRunner(processor);
+        AuthUtils.enableAccessKey(runner, "abcd", "defg");
+        return runner;
+    }
+
+    @BeforeEach
+    public void setUp() throws InitializationException {
+        processor = new StartAwsTranslateJob() {
+            @Override
+            public TranslateClient getClient(ProcessContext context) {
+                return mockTranslateClient;
+            }
+        };
+        runner = createRunner(processor);
+    }
+
+    @Test
+    public void testSuccessfulFlowfileContent() throws JsonProcessingException {
+        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
+                .terminologyNames("Name")
+                .build();
+        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
+                .jobId(TEST_TASK_ID)
+                .build();
+        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);
+
+        final String requestJson = serialize(request);
+        runner.enqueue(requestJson);
+        runner.run();
+
+        runner.assertTransferCount(REL_SUCCESS, 1);
+        runner.assertTransferCount(REL_ORIGINAL, 1);
+        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);
+
+        assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
+        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+    }
+
+    @Test
+    public void testSuccessfulAttribute() throws JsonProcessingException {
+        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
+                .terminologyNames("Name")
+                .build();
+        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
+                .jobId(TEST_TASK_ID)
+                .build();
+        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);
+
+        final String requestJson = serialize(request);
+        runner.setProperty(StartAwsTranslateJob.JSON_PAYLOAD, "${json.payload}");
+        final Map<String, String> attributes = new HashMap<>();
+        attributes.put("json.payload", requestJson);
+        runner.enqueue("", attributes);
+        runner.run();
+
+        runner.assertTransferCount(REL_SUCCESS, 1);
+        runner.assertTransferCount(REL_ORIGINAL, 1);
+        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
+        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);
+
+        assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
+        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
+    }
+
+    @Test
+    public void testInvalidJson() {
+        final String requestJson = "invalid";
+        runner.enqueue(requestJson);
+        runner.run();
+
+        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+    }
+
+    @Test
+    public void testServiceFailure() throws JsonProcessingException {
+        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
+                .terminologyNames("Name")
+                .build();
+        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
+
+        final String requestJson = serialize(request);
+        runner.enqueue(requestJson);
+        runner.run();
+
+        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
+    }
+
+    private StartTextTranslationJobResponse deserialize(final String responseData) throws JsonProcessingException {
+        return objectMapper.readValue(responseData, StartTextTranslationJobResponse.serializableBuilderClass()).build();
+    }
+
+    private String serialize(final StartTextTranslationJobRequest request) throws JsonProcessingException {
+        return objectMapper.writeValueAsString(request.toBuilder());
+    }
+}
\ No newline at end of file
diff --git a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java
index 4e93560297..c0c7981b79 100644
--- a/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java
+++ b/nifi-nar-bundles/nifi-aws-bundle/nifi-aws-processors/src/test/java/org/apache/nifi/processors/aws/sqs/AbstractSQSIT.java
@@ -66,7 +66,6 @@ public abstract class AbstractSQSIT {
         queueUrl = response.queueUrl();
     }
 
-    @AfterAll
     public static void shutdown() {
         client.close();