NIFI-12263 Upgraded AWS Machine Learning processors to SDK 2

This closes #7953

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Joe Gresock 2023-10-28 11:40:59 -04:00 committed by exceptionfactory
parent c706877147
commit 77834c92df
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
21 changed files with 1542 additions and 608 deletions

View File

@ -65,6 +65,22 @@
<groupId>software.amazon.awssdk</groupId> <groupId>software.amazon.awssdk</groupId>
<artifactId>firehose</artifactId> <artifactId>firehose</artifactId>
</dependency> </dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>polly</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>textract</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>transcribe</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>translate</artifactId>
</dependency>
<dependency> <dependency>
<groupId>software.amazon.kinesis</groupId> <groupId>software.amazon.kinesis</groupId>
<artifactId>amazon-kinesis-client</artifactId> <artifactId>amazon-kinesis-client</artifactId>
@ -130,6 +146,10 @@
<groupId>com.github.ben-manes.caffeine</groupId> <groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId> <artifactId>caffeine</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
</dependency>
<!-- Version 2 of the AmazonS3EncryptionClient requires bouncy castle --> <!-- Version 2 of the AmazonS3EncryptionClient requires bouncy castle -->
<dependency> <dependency>
<groupId>org.bouncycastle</groupId> <groupId>org.bouncycastle</groupId>

View File

@ -17,14 +17,6 @@
package org.apache.nifi.processors.aws.ml; package org.apache.nifi.processors.aws.ml;
import com.amazonaws.AmazonWebServiceClient;
import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.AmazonWebServiceResult;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature; import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
@ -33,12 +25,18 @@ import org.apache.commons.io.IOUtils;
import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor; import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
import software.amazon.awssdk.awscore.AwsRequest;
import software.amazon.awssdk.awscore.AwsResponse;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.SdkClient;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -46,10 +44,15 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import static org.apache.nifi.flowfile.attributes.CoreAttributes.MIME_TYPE; import static org.apache.nifi.flowfile.attributes.CoreAttributes.MIME_TYPE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceClient, REQUEST extends AmazonWebServiceRequest, RESPONSE extends AmazonWebServiceResult> public abstract class AbstractAwsMachineLearningJobStarter<
extends AbstractAWSCredentialsProviderProcessor<T> { Q extends AwsRequest,
B extends AwsRequest.Builder,
R extends AwsResponse,
T extends SdkClient,
U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
extends AbstractAwsSyncProcessor<T, U> {
public static final PropertyDescriptor JSON_PAYLOAD = new PropertyDescriptor.Builder() public static final PropertyDescriptor JSON_PAYLOAD = new PropertyDescriptor.Builder()
.name("json-payload") .name("json-payload")
.displayName("JSON Payload") .displayName("JSON Payload")
@ -62,18 +65,17 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE) new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE)
.required(true) .required(true)
.build(); .build();
public static final PropertyDescriptor REGION = new PropertyDescriptor.Builder()
.displayName("Region")
.name("aws-region")
.required(true)
.allowableValues(getAvailableRegions())
.defaultValue(createAllowableValue(Regions.DEFAULT_REGION).getValue())
.build();
public static final Relationship REL_ORIGINAL = new Relationship.Builder() public static final Relationship REL_ORIGINAL = new Relationship.Builder()
.name("original") .name("original")
.description("Upon successful completion, the original FlowFile will be routed to this relationship.") .description("Upon successful completion, the original FlowFile will be routed to this relationship.")
.autoTerminateDefault(true) .autoTerminateDefault(true)
.build(); .build();
@Override
public void migrateProperties(final PropertyConfiguration config) {
config.renameProperty("aws-region", REGION.getName());
}
protected static final List<PropertyDescriptor> PROPERTIES = List.of( protected static final List<PropertyDescriptor> PROPERTIES = List.of(
MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE, MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE,
REGION, REGION,
@ -84,10 +86,9 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
private final static ObjectMapper MAPPER = JsonMapper.builder() private final static ObjectMapper MAPPER = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true) .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.findAndAddModules()
.build(); .build();
private static final Set<Relationship> relationships = Set.of(REL_ORIGINAL, private static final Set<Relationship> relationships = Set.of(REL_ORIGINAL, REL_SUCCESS, REL_FAILURE);
REL_SUCCESS,
REL_FAILURE);
@Override @Override
public Set<Relationship> getRelationships() { public Set<Relationship> getRelationships() {
@ -105,14 +106,14 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
if (flowFile == null && !context.getProperty(JSON_PAYLOAD).isSet()) { if (flowFile == null && !context.getProperty(JSON_PAYLOAD).isSet()) {
return; return;
} }
final RESPONSE response; final R response;
FlowFile childFlowFile; FlowFile childFlowFile;
try { try {
response = sendRequest(buildRequest(session, context, flowFile), context, flowFile); response = sendRequest(buildRequest(session, context, flowFile), context, flowFile);
childFlowFile = writeToFlowFile(session, flowFile, response); childFlowFile = writeToFlowFile(session, flowFile, response);
postProcessFlowFile(context, session, childFlowFile, response); childFlowFile = postProcessFlowFile(context, session, childFlowFile, response);
session.transfer(childFlowFile, REL_SUCCESS); session.transfer(childFlowFile, REL_SUCCESS);
} catch (Exception e) { } catch (final Exception e) {
if (flowFile != null) { if (flowFile != null) {
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
} }
@ -125,26 +126,21 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
} }
protected void postProcessFlowFile(ProcessContext context, ProcessSession session, FlowFile flowFile, RESPONSE response) { protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, final FlowFile flowFile, final R response) {
final String awsTaskId = getAwsTaskId(context, response, flowFile); final String awsTaskId = getAwsTaskId(context, response, flowFile);
flowFile = session.putAttribute(flowFile, TASK_ID.getName(), awsTaskId); FlowFile processedFlowFile = session.putAttribute(flowFile, TASK_ID.getName(), awsTaskId);
flowFile = session.putAttribute(flowFile, MIME_TYPE.key(), "application/json"); processedFlowFile = session.putAttribute(processedFlowFile, MIME_TYPE.key(), "application/json");
getLogger().debug("AWS ML Task [{}] started", awsTaskId); getLogger().debug("AWS ML Task [{}] started", awsTaskId);
return processedFlowFile;
} }
protected REQUEST buildRequest(ProcessSession session, ProcessContext context, FlowFile flowFile) throws JsonProcessingException { protected Q buildRequest(final ProcessSession session, final ProcessContext context, final FlowFile flowFile) throws JsonProcessingException {
return MAPPER.readValue(getPayload(session, context, flowFile), getAwsRequestClass(context, flowFile)); return (Q) MAPPER.readValue(getPayload(session, context, flowFile), getAwsRequestBuilderClass(context, flowFile)).build();
} }
@Override protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final R response) {
protected T createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
throw new UnsupportedOperationException();
}
protected FlowFile writeToFlowFile(ProcessSession session, FlowFile flowFile, RESPONSE response) {
FlowFile childFlowFile = flowFile == null ? session.create() : session.create(flowFile); FlowFile childFlowFile = flowFile == null ? session.create() : session.create(flowFile);
childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response)); childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
return childFlowFile; return childFlowFile;
} }
@ -156,7 +152,7 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
} }
} }
private String getPayload(ProcessSession session, ProcessContext context, FlowFile flowFile) { private String getPayload(final ProcessSession session, final ProcessContext context, final FlowFile flowFile) {
String payloadPropertyValue = context.getProperty(JSON_PAYLOAD).evaluateAttributeExpressions(flowFile).getValue(); String payloadPropertyValue = context.getProperty(JSON_PAYLOAD).evaluateAttributeExpressions(flowFile).getValue();
if (payloadPropertyValue == null) { if (payloadPropertyValue == null) {
payloadPropertyValue = readFlowFile(session, flowFile); payloadPropertyValue = readFlowFile(session, flowFile);
@ -164,9 +160,9 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
return payloadPropertyValue; return payloadPropertyValue;
} }
abstract protected RESPONSE sendRequest(REQUEST request, ProcessContext context, FlowFile flowFile) throws JsonProcessingException; abstract protected R sendRequest(Q request, ProcessContext context, FlowFile flowFile) throws JsonProcessingException;
abstract protected Class<? extends REQUEST> getAwsRequestClass(ProcessContext context, FlowFile flowFile); abstract protected Class<? extends B> getAwsRequestBuilderClass(ProcessContext context, FlowFile flowFile);
abstract protected String getAwsTaskId(ProcessContext context, RESPONSE response, FlowFile flowFile); abstract protected String getAwsTaskId(ProcessContext context, R response, FlowFile flowFile);
} }

View File

@ -17,33 +17,30 @@
package org.apache.nifi.processors.aws.ml; package org.apache.nifi.processors.aws.ml;
import com.amazonaws.AmazonWebServiceClient;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.ResponseMetadata;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.http.SdkHttpMetadata;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.databind.MapperFeature; import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper; import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor; import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
import software.amazon.awssdk.awscore.AwsResponse;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.SdkClient;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import static org.apache.nifi.expression.ExpressionLanguageScope.FLOWFILE_ATTRIBUTES; import static org.apache.nifi.expression.ExpressionLanguageScope.FLOWFILE_ATTRIBUTES;
public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebServiceClient> public abstract class AbstractAwsMachineLearningJobStatusProcessor<
extends AbstractAWSCredentialsProviderProcessor<T> { T extends SdkClient,
U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
extends AbstractAwsSyncProcessor<T, U> {
public static final String AWS_TASK_OUTPUT_LOCATION = "outputLocation"; public static final String AWS_TASK_OUTPUT_LOCATION = "outputLocation";
public static final PropertyDescriptor MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE = public static final PropertyDescriptor MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE =
new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE) new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE)
@ -81,13 +78,6 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
.description("The job failed, the original FlowFile will be routed to this relationship.") .description("The job failed, the original FlowFile will be routed to this relationship.")
.autoTerminateDefault(true) .autoTerminateDefault(true)
.build(); .build();
public static final PropertyDescriptor REGION = new PropertyDescriptor.Builder()
.displayName("Region")
.name("aws-region")
.required(true)
.allowableValues(getAvailableRegions())
.defaultValue(createAllowableValue(Regions.DEFAULT_REGION).getValue())
.build();
public static final String FAILURE_REASON_ATTRIBUTE = "failure.reason"; public static final String FAILURE_REASON_ATTRIBUTE = "failure.reason";
protected static final List<PropertyDescriptor> PROPERTIES = List.of( protected static final List<PropertyDescriptor> PROPERTIES = List.of(
TASK_ID, TASK_ID,
@ -99,18 +89,14 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
PROXY_CONFIGURATION_SERVICE); PROXY_CONFIGURATION_SERVICE);
private static final ObjectMapper MAPPER = JsonMapper.builder() private static final ObjectMapper MAPPER = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true) .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.findAndAddModules()
.build(); .build();
static { @Override
SimpleModule awsResponseModule = new SimpleModule(); public void migrateProperties(final PropertyConfiguration config) {
awsResponseModule.addDeserializer(ResponseMetadata.class, new AwsResponseMetadataDeserializer()); config.renameProperty("aws-region", REGION.getName());
SimpleModule sdkHttpModule = new SimpleModule();
awsResponseModule.addDeserializer(SdkHttpMetadata.class, new SdkHttpMetadataDeserializer());
MAPPER.registerModule(awsResponseModule);
MAPPER.registerModule(sdkHttpModule);
} }
@Override @Override
public Set<Relationship> getRelationships() { public Set<Relationship> getRelationships() {
return relationships; return relationships;
@ -129,13 +115,7 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
return PROPERTIES; return PROPERTIES;
} }
@Override protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final AwsResponse response) {
protected T createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, return session.write(flowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
throw new UnsupportedOperationException();
}
protected void writeToFlowFile(ProcessSession session, FlowFile flowFile, Object response) {
session.write(flowFile, out -> MAPPER.writeValue(out, response));
} }
} }

View File

@ -125,22 +125,6 @@
<groupId>org.apache.nifi</groupId> <groupId>org.apache.nifi</groupId>
<artifactId>nifi-schema-registry-service-api</artifactId> <artifactId>nifi-schema-registry-service-api</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-translate</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-polly</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-transcribe</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-textract</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.bouncycastle</groupId> <groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk18on</artifactId> <artifactId>bcprov-jdk18on</artifactId>

View File

@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.polly; package org.apache.nifi.processors.aws.ml.polly;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
import com.amazonaws.services.polly.model.TaskStatus;
import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.behavior.WritesAttribute; import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
@ -34,9 +25,17 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.PollyClientBuilder;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.TaskStatus;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@ -45,10 +44,10 @@ import java.util.regex.Pattern;
@SeeAlso({StartAwsPollyJob.class}) @SeeAlso({StartAwsPollyJob.class})
@WritesAttributes({ @WritesAttributes({
@WritesAttribute(attribute = "PollyS3OutputBucket", description = "The bucket name where polly output will be located."), @WritesAttribute(attribute = "PollyS3OutputBucket", description = "The bucket name where polly output will be located."),
@WritesAttribute(attribute = "PollyS3OutputKey", description = "Object key of polly output."), @WritesAttribute(attribute = "filename", description = "Object key of polly output."),
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.") @WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
}) })
public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonPollyClient> { public class GetAwsPollyJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<PollyClient, PollyClientBuilder> {
private static final String BUCKET = "bucket"; private static final String BUCKET = "bucket";
private static final String KEY = "key"; private static final String KEY = "key";
private static final Pattern S3_PATH = Pattern.compile("https://s3.*amazonaws.com/(?<" + BUCKET + ">[^/]+)/(?<" + KEY + ">.*)"); private static final Pattern S3_PATH = Pattern.compile("https://s3.*amazonaws.com/(?<" + BUCKET + ">[^/]+)/(?<" + KEY + ">.*)");
@ -56,65 +55,66 @@ public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor<A
private static final String AWS_S3_KEY = "filename"; private static final String AWS_S3_KEY = "filename";
@Override @Override
protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, public Set<Relationship> getRelationships() {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { final Set<Relationship> parentRelationships = new HashSet<>(super.getRelationships());
return (AmazonPollyClient) AmazonPollyClient.builder() parentRelationships.remove(REL_THROTTLED);
.withCredentials(credentialsProvider) return Set.copyOf(parentRelationships);
.withRegion(context.getProperty(REGION).getValue())
.withEndpointConfiguration(endpointConfiguration)
.withClientConfiguration(config)
.build();
} }
@Override @Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
return PollyClient.builder();
}
@Override
public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get(); FlowFile flowFile = session.get();
if (flowFile == null) { if (flowFile == null) {
return; return;
} }
GetSpeechSynthesisTaskResult speechSynthesisTask; final GetSpeechSynthesisTaskResponse speechSynthesisTask;
try { try {
speechSynthesisTask = getSynthesisTask(context, flowFile); speechSynthesisTask = getSynthesisTask(context, flowFile);
} catch (ThrottlingException e) { } catch (final Exception e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
return;
} catch (Exception e) {
getLogger().warn("Failed to get Polly Job status", e); getLogger().warn("Failed to get Polly Job status", e);
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
return; return;
} }
TaskStatus taskStatus = TaskStatus.fromValue(speechSynthesisTask.getSynthesisTask().getTaskStatus()); final TaskStatus taskStatus = speechSynthesisTask.synthesisTask().taskStatus();
if (taskStatus == TaskStatus.InProgress || taskStatus == TaskStatus.Scheduled) { if (taskStatus == TaskStatus.IN_PROGRESS || taskStatus == TaskStatus.SCHEDULED) {
session.penalize(flowFile); flowFile = session.penalize(flowFile);
session.transfer(flowFile, REL_RUNNING); session.transfer(flowFile, REL_RUNNING);
} else if (taskStatus == TaskStatus.Completed) { } else if (taskStatus == TaskStatus.COMPLETED) {
String outputUri = speechSynthesisTask.getSynthesisTask().getOutputUri(); final String outputUri = speechSynthesisTask.synthesisTask().outputUri();
Matcher matcher = S3_PATH.matcher(outputUri); final Matcher matcher = S3_PATH.matcher(outputUri);
if (matcher.find()) { if (matcher.find()) {
session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET)); flowFile = session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET));
session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY)); flowFile = session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY));
} }
FlowFile childFlowFile = session.create(flowFile); FlowFile childFlowFile = session.create(flowFile);
writeToFlowFile(session, childFlowFile, speechSynthesisTask); childFlowFile = writeToFlowFile(session, childFlowFile, speechSynthesisTask);
childFlowFile = session.putAttribute(childFlowFile, AWS_TASK_OUTPUT_LOCATION, outputUri); childFlowFile = session.putAttribute(childFlowFile, AWS_TASK_OUTPUT_LOCATION, outputUri);
session.transfer(flowFile, REL_ORIGINAL); session.transfer(flowFile, REL_ORIGINAL);
session.transfer(childFlowFile, REL_SUCCESS); session.transfer(childFlowFile, REL_SUCCESS);
getLogger().info("Amazon Polly Task Completed {}", flowFile); getLogger().info("Amazon Polly Task Completed {}", flowFile);
} else if (taskStatus == TaskStatus.Failed) { } else if (taskStatus == TaskStatus.FAILED) {
final String failureReason = speechSynthesisTask.getSynthesisTask().getTaskStatusReason(); final String failureReason = speechSynthesisTask.synthesisTask().taskStatusReason();
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason); flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Polly Task Failed {} Reason [{}]", flowFile, failureReason); getLogger().error("Amazon Polly Task Failed {} Reason [{}]", flowFile, failureReason);
} else if (taskStatus == TaskStatus.UNKNOWN_TO_SDK_VERSION) {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Polly Task Failed {} Reason [Unrecognized job status]", flowFile);
} }
} }
private GetSpeechSynthesisTaskResult getSynthesisTask(ProcessContext context, FlowFile flowFile) { private GetSpeechSynthesisTaskResponse getSynthesisTask(final ProcessContext context, final FlowFile flowFile) {
String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
GetSpeechSynthesisTaskRequest request = new GetSpeechSynthesisTaskRequest().withTaskId(taskId); final GetSpeechSynthesisTaskRequest request = GetSpeechSynthesisTaskRequest.builder().taskId(taskId).build();
return getClient(context).getSpeechSynthesisTask(request); return getClient(context).getSpeechSynthesisTask(request);
} }
} }

View File

@ -17,49 +17,45 @@
package org.apache.nifi.processors.aws.ml.polly; package org.apache.nifi.processors.aws.ml.polly;
import com.amazonaws.ClientConfiguration; import org.apache.nifi.annotation.behavior.WritesAttribute;
import com.amazonaws.auth.AWSCredentialsProvider; import org.apache.nifi.annotation.behavior.WritesAttributes;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskResult;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.PollyClientBuilder;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Polly"}) @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Polly"})
@CapabilityDescription("Trigger a AWS Polly job. It should be followed by GetAwsPollyJobStatus processor in order to monitor job status.") @CapabilityDescription("Trigger a AWS Polly job. It should be followed by GetAwsPollyJobStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsPollyJobStatus")
})
@SeeAlso({GetAwsPollyJobStatus.class}) @SeeAlso({GetAwsPollyJobStatus.class})
public class StartAwsPollyJob extends AwsMachineLearningJobStarter<AmazonPollyClient, StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskResult> { public class StartAwsPollyJob extends AbstractAwsMachineLearningJobStarter<
StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskRequest.Builder, StartSpeechSynthesisTaskResponse, PollyClient, PollyClientBuilder> {
@Override @Override
protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { return PollyClient.builder();
return (AmazonPollyClient) AmazonPollyClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
} }
@Override @Override
protected StartSpeechSynthesisTaskResult sendRequest(StartSpeechSynthesisTaskRequest request, ProcessContext context, FlowFile flowFile) { protected StartSpeechSynthesisTaskResponse sendRequest(final StartSpeechSynthesisTaskRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startSpeechSynthesisTask(request); return getClient(context).startSpeechSynthesisTask(request);
} }
@Override @Override
protected Class<? extends StartSpeechSynthesisTaskRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) { protected Class<? extends StartSpeechSynthesisTaskRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartSpeechSynthesisTaskRequest.class; return StartSpeechSynthesisTaskRequest.serializableBuilderClass();
} }
@Override @Override
protected String getAwsTaskId(ProcessContext context, StartSpeechSynthesisTaskResult startSpeechSynthesisTaskResult, FlowFile flowFile) { protected String getAwsTaskId(final ProcessContext context, final StartSpeechSynthesisTaskResponse startSpeechSynthesisTaskResponse, final FlowFile flowFile) {
return startSpeechSynthesisTaskResult.getSynthesisTask().getTaskId(); return startSpeechSynthesisTaskResponse.synthesisTask().taskId();
} }
} }

View File

@ -17,48 +17,61 @@
package org.apache.nifi.processors.aws.ml.textract; package org.apache.nifi.processors.aws.ml.textract;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.GetDocumentTextDetectionRequest;
import com.amazonaws.services.textract.model.GetExpenseAnalysisRequest;
import com.amazonaws.services.textract.model.JobStatus;
import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.expression.ExpressionLanguageScope; import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.TextractClientBuilder;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.JobStatus;
import software.amazon.awssdk.services.textract.model.TextractResponse;
import software.amazon.awssdk.services.textract.model.ThrottlingException;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS; import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"}) @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Retrieves the current status of an AWS Textract job.") @CapabilityDescription("Retrieves the current status of an AWS Textract job.")
@SeeAlso({StartAwsTextractJob.class}) @SeeAlso({StartAwsTextractJob.class})
public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTextractClient> { public class GetAwsTextractJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TextractClient, TextractClientBuilder> {
public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() {
@Override
public ValidationResult validate(final String subject, final String value, final ValidationContext context) {
if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
} else if (TextractType.TEXTRACT_TYPES.contains(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
} else {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
}
}
};
public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder() public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
.name("textract-type") .name("textract-type")
.displayName("Textract Type") .displayName("Textract Type")
.required(true) .required(true)
.description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"") .description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
.allowableValues(TextractType.TEXTRACT_TYPES)
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
.defaultValue(DOCUMENT_ANALYSIS.getType()) .defaultValue(String.format("${%s}", TEXTRACT_TYPE_ATTRIBUTE))
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR) .addValidator(TEXTRACT_TYPE_VALIDATOR)
.build(); .build();
private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES = private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList())); Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@ -69,30 +82,24 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
} }
@Override @Override
protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, protected TextractClientBuilder createClientBuilder(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { return TextractClient.builder();
return (AmazonTextractClient) AmazonTextractClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.withCredentials(credentialsProvider)
.build();
} }
@Override @Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get(); FlowFile flowFile = session.get();
if (flowFile == null) { if (flowFile == null) {
return; return;
} }
String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue(); final String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue();
String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); final String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
try { try {
JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId); final JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId);
if (JobStatus.SUCCEEDED == jobStatus) { if (JobStatus.SUCCEEDED == jobStatus) {
Object task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId); final TextractResponse task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId);
writeToFlowFile(session, flowFile, task); flowFile = writeToFlowFile(session, flowFile, task);
session.transfer(flowFile, REL_SUCCESS); session.transfer(flowFile, REL_SUCCESS);
} else if (JobStatus.IN_PROGRESS == jobStatus) { } else if (JobStatus.IN_PROGRESS == jobStatus) {
session.transfer(flowFile, REL_RUNNING); session.transfer(flowFile, REL_RUNNING);
@ -101,29 +108,31 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
} else if (JobStatus.FAILED == jobStatus) { } else if (JobStatus.FAILED == jobStatus) {
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Textract Task [{}] Failed", awsTaskId); getLogger().error("Amazon Textract Task [{}] Failed", awsTaskId);
} else {
throw new IllegalStateException("Unrecognized job status");
} }
} catch (ThrottlingException e) { } catch (final ThrottlingException e) {
getLogger().info("Request Rate Limit exceeded", e); getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED); session.transfer(flowFile, REL_THROTTLED);
} catch (Exception e) { } catch (final Exception e) {
getLogger().warn("Failed to get Textract Job status", e); getLogger().warn("Failed to get Textract Job status", e);
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
} }
} }
private Object getTask(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) { private TextractResponse getTask(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
return switch (typeOfTextract) { return switch (typeOfTextract) {
case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId)); case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build());
case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId)); case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build());
case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId)); case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build());
}; };
} }
private JobStatus getTaskStatus(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) { private JobStatus getTaskStatus(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
return switch (typeOfTextract) { return switch (typeOfTextract) {
case DOCUMENT_ANALYSIS -> JobStatus.fromValue(client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId)).getJobStatus()); case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
case DOCUMENT_TEXT_DETECTION -> JobStatus.fromValue(client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId)).getJobStatus()); case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build()).jobStatus();
case EXPENSE_ANALYSIS -> JobStatus.fromValue(client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId)).getJobStatus()); case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
}; };
} }
} }

View File

@ -17,31 +17,26 @@
package org.apache.nifi.processors.aws.ml.textract; package org.apache.nifi.processors.aws.ml.textract;
import com.amazonaws.AmazonWebServiceRequest; import org.apache.nifi.annotation.behavior.WritesAttribute;
import com.amazonaws.AmazonWebServiceResult; import org.apache.nifi.annotation.behavior.WritesAttributes;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.StartDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.StartDocumentAnalysisResult;
import com.amazonaws.services.textract.model.StartDocumentTextDetectionRequest;
import com.amazonaws.services.textract.model.StartDocumentTextDetectionResult;
import com.amazonaws.services.textract.model.StartExpenseAnalysisRequest;
import com.amazonaws.services.textract.model.StartExpenseAnalysisResult;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.TextractClientBuilder;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;
import software.amazon.awssdk.services.textract.model.TextractRequest;
import software.amazon.awssdk.services.textract.model.TextractResponse;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -52,28 +47,23 @@ import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_A
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"}) @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Trigger a AWS Textract job. It should be followed by GetAwsTextractJobStatus processor in order to monitor job status.") @CapabilityDescription("Trigger a AWS Textract job. It should be followed by GetAwsTextractJobStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTextractJobStatus"),
@WritesAttribute(attribute = "awsTextractType", description = "The selected Textract type, which can be used in GetAwsTextractJobStatus")
})
@SeeAlso({GetAwsTextractJobStatus.class}) @SeeAlso({GetAwsTextractJobStatus.class})
public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonTextractClient, AmazonWebServiceRequest, AmazonWebServiceResult> { public class StartAwsTextractJob extends AbstractAwsMachineLearningJobStarter<
public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() { TextractRequest, TextractRequest.Builder, TextractResponse, TextractClient, TextractClientBuilder> {
@Override
public ValidationResult validate(final String subject, final String value, final ValidationContext context) { public static final String TEXTRACT_TYPE_ATTRIBUTE = "awsTextractType";
if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
} else if (TextractType.TEXTRACT_TYPES.contains(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
} else {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
}
}
};
public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder() public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
.name("textract-type") .name("textract-type")
.displayName("Textract Type") .displayName("Textract Type")
.required(true) .required(true)
.description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"") .description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES) .allowableValues(TextractType.TEXTRACT_TYPES)
.defaultValue(DOCUMENT_ANALYSIS.type) .defaultValue(DOCUMENT_ANALYSIS.getType())
.addValidator(TEXTRACT_TYPE_VALIDATOR)
.build(); .build();
private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES = private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList())); Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@ -84,24 +74,13 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonText
} }
@Override @Override
protected void postProcessFlowFile(ProcessContext context, ProcessSession session, FlowFile flowFile, AmazonWebServiceResult response) { protected TextractClientBuilder createClientBuilder(final ProcessContext context) {
super.postProcessFlowFile(context, session, flowFile, response); return TextractClient.builder();
} }
@Override @Override
protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, protected TextractResponse sendRequest(final TextractRequest request, final ProcessContext context, final FlowFile flowFile) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return (AmazonTextractClient) AmazonTextractClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
}
@Override
protected AmazonWebServiceResult sendRequest(AmazonWebServiceRequest request, ProcessContext context, FlowFile flowFile) {
TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
return switch (textractType) { return switch (textractType) {
case DOCUMENT_ANALYSIS -> getClient(context).startDocumentAnalysis((StartDocumentAnalysisRequest) request); case DOCUMENT_ANALYSIS -> getClient(context).startDocumentAnalysis((StartDocumentAnalysisRequest) request);
case DOCUMENT_TEXT_DETECTION -> getClient(context).startDocumentTextDetection((StartDocumentTextDetectionRequest) request); case DOCUMENT_TEXT_DETECTION -> getClient(context).startDocumentTextDetection((StartDocumentTextDetectionRequest) request);
@ -110,22 +89,28 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonText
} }
@Override @Override
protected Class<? extends AmazonWebServiceRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) { protected Class<? extends TextractRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue()); final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return switch (typeOfTextract) { return switch (typeOfTextract) {
case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.class; case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.serializableBuilderClass();
case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.class; case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.serializableBuilderClass();
case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.class; case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.serializableBuilderClass();
}; };
} }
@Override @Override
protected String getAwsTaskId(ProcessContext context, AmazonWebServiceResult amazonWebServiceResult, FlowFile flowFile) { protected String getAwsTaskId(final ProcessContext context, final TextractResponse textractResponse, final FlowFile flowFile) {
final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue()); final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return switch (textractType) { return switch (textractType) {
case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResult) amazonWebServiceResult).getJobId(); case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResponse) textractResponse).jobId();
case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResult) amazonWebServiceResult).getJobId(); case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResponse) textractResponse).jobId();
case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResult) amazonWebServiceResult).getJobId(); case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResponse) textractResponse).jobId();
}; };
} }
@Override
protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, FlowFile flowFile, final TextractResponse response) {
flowFile = super.postProcessFlowFile(context, session, flowFile, response);
return session.putAttribute(flowFile, TEXTRACT_TYPE_ATTRIBUTE, context.getProperty(TEXTRACT_TYPE).getValue());
}
} }

View File

@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.transcribe; package org.apache.nifi.processors.aws.ml.transcribe;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.model.ThrottlingException;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
import org.apache.nifi.annotation.behavior.WritesAttribute; import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
@ -35,7 +26,13 @@ import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.LimitExceededException;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"}) @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
@CapabilityDescription("Retrieves the current status of an AWS Transcribe job.") @CapabilityDescription("Retrieves the current status of an AWS Transcribe job.")
@ -43,55 +40,50 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
@WritesAttributes({ @WritesAttributes({
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.") @WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
}) })
public class GetAwsTranscribeJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranscribeClient> { public class GetAwsTranscribeJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranscribeClient, TranscribeClientBuilder> {
@Override @Override
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { return TranscribeClient.builder();
return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withEndpointConfiguration(endpointConfiguration)
.withClientConfiguration(config)
.build();
} }
@Override @Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get(); FlowFile flowFile = session.get();
if (flowFile == null) { if (flowFile == null) {
return; return;
} }
try { try {
GetTranscriptionJobResult job = getJob(context, flowFile); final GetTranscriptionJobResponse job = getJob(context, flowFile);
TranscriptionJobStatus jobStatus = TranscriptionJobStatus.fromValue(job.getTranscriptionJob().getTranscriptionJobStatus()); final TranscriptionJobStatus status = job.transcriptionJob().transcriptionJobStatus();
if (TranscriptionJobStatus.COMPLETED == jobStatus) { if (TranscriptionJobStatus.COMPLETED == status) {
writeToFlowFile(session, flowFile, job); flowFile = writeToFlowFile(session, flowFile, job);
session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.getTranscriptionJob().getTranscript().getTranscriptFileUri()); flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.transcriptionJob().transcript().transcriptFileUri());
session.transfer(flowFile, REL_SUCCESS); session.transfer(flowFile, REL_SUCCESS);
} else if (TranscriptionJobStatus.IN_PROGRESS == jobStatus) { } else if (TranscriptionJobStatus.IN_PROGRESS == status || TranscriptionJobStatus.QUEUED == status) {
session.transfer(flowFile, REL_RUNNING); session.transfer(flowFile, REL_RUNNING);
} else if (TranscriptionJobStatus.FAILED == jobStatus) { } else if (TranscriptionJobStatus.FAILED == status) {
final String failureReason = job.getTranscriptionJob().getFailureReason(); final String failureReason = job.transcriptionJob().failureReason();
session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason); flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
getLogger().error("Transcribe Task Failed {} Reason [{}]", flowFile, failureReason); getLogger().error("Transcribe Task Failed {} Reason [{}]", flowFile, failureReason);
} else {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
throw new IllegalStateException("Unrecognized job status");
} }
} catch (ThrottlingException e) { } catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e); getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED); session.transfer(flowFile, REL_THROTTLED);
return; } catch (final Exception e) {
} catch (Exception e) {
getLogger().warn("Failed to get Transcribe Job status", e); getLogger().warn("Failed to get Transcribe Job status", e);
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
return;
} }
} }
private GetTranscriptionJobResult getJob(ProcessContext context, FlowFile flowFile) { private GetTranscriptionJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue(); final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
GetTranscriptionJobRequest request = new GetTranscriptionJobRequest().withTranscriptionJobName(taskId); final GetTranscriptionJobRequest request = GetTranscriptionJobRequest.builder().transcriptionJobName(taskId).build();
return getClient(context).getTranscriptionJob(request); return getClient(context).getTranscriptionJob(request);
} }
} }

View File

@ -17,48 +17,45 @@
package org.apache.nifi.processors.aws.ml.transcribe; package org.apache.nifi.processors.aws.ml.transcribe;
import com.amazonaws.ClientConfiguration; import org.apache.nifi.annotation.behavior.WritesAttribute;
import com.amazonaws.auth.AWSCredentialsProvider; import org.apache.nifi.annotation.behavior.WritesAttributes;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.StartTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.StartTranscriptionJobResult;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"}) @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
@CapabilityDescription("Trigger a AWS Transcribe job. It should be followed by GetAwsTranscribeStatus processor in order to monitor job status.") @CapabilityDescription("Trigger a AWS Transcribe job. It should be followed by GetAwsTranscribeStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranscribeJobStatus")
})
@SeeAlso({GetAwsTranscribeJobStatus.class}) @SeeAlso({GetAwsTranscribeJobStatus.class})
public class StartAwsTranscribeJob extends AwsMachineLearningJobStarter<AmazonTranscribeClient, StartTranscriptionJobRequest, StartTranscriptionJobResult> { public class StartAwsTranscribeJob extends AbstractAwsMachineLearningJobStarter<
StartTranscriptionJobRequest, StartTranscriptionJobRequest.Builder, StartTranscriptionJobResponse, TranscribeClient, TranscribeClientBuilder> {
@Override @Override
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { return TranscribeClient.builder();
return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.withCredentials(credentialsProvider)
.build();
} }
@Override @Override
protected StartTranscriptionJobResult sendRequest(StartTranscriptionJobRequest request, ProcessContext context, FlowFile flowFile) { protected StartTranscriptionJobResponse sendRequest(final StartTranscriptionJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTranscriptionJob(request); return getClient(context).startTranscriptionJob(request);
} }
@Override @Override
protected Class<? extends StartTranscriptionJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) { protected Class<? extends StartTranscriptionJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartTranscriptionJobRequest.class; return StartTranscriptionJobRequest.serializableBuilderClass();
} }
@Override @Override
protected String getAwsTaskId(ProcessContext context, StartTranscriptionJobResult startTranscriptionJobResult, FlowFile flowFile) { protected String getAwsTaskId(final ProcessContext context, final StartTranscriptionJobResponse startTranscriptionJobResponse, final FlowFile flowFile) {
return startTranscriptionJobResult.getTranscriptionJob().getTranscriptionJobName(); return startTranscriptionJobResponse.transcriptionJob().transcriptionJobName();
} }
} }

View File

@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.translate; package org.apache.nifi.processors.aws.ml.translate;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.model.ThrottlingException;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
import com.amazonaws.services.translate.model.JobStatus;
import org.apache.nifi.annotation.behavior.WritesAttribute; import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes; import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
@ -34,8 +25,15 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.TranslateClientBuilder;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
import software.amazon.awssdk.services.translate.model.JobStatus;
import software.amazon.awssdk.services.translate.model.LimitExceededException;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"}) @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Retrieves the current status of an AWS Translate job.") @CapabilityDescription("Retrieves the current status of an AWS Translate job.")
@ -43,54 +41,66 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
@WritesAttributes({ @WritesAttributes({
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.") @WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
}) })
public class GetAwsTranslateJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranslateClient> { public class GetAwsTranslateJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranslateClient, TranslateClientBuilder> {
@Override @Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { return TranslateClient.builder();
return (AmazonTranslateClient) AmazonTranslateClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
} }
@Override @Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException { public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get(); FlowFile flowFile = session.get();
if (flowFile == null) { if (flowFile == null) {
return; return;
} }
String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
try { try {
DescribeTextTranslationJobResult describeTextTranslationJobResult = getStatusString(context, awsTaskId); final DescribeTextTranslationJobResponse job = getJob(context, flowFile);
JobStatus status = JobStatus.fromValue(describeTextTranslationJobResult.getTextTranslationJobProperties().getJobStatus()); final JobStatus status = job.textTranslationJobProperties().jobStatus();
if (status == JobStatus.IN_PROGRESS || status == JobStatus.SUBMITTED) { flowFile = writeToFlowFile(session, flowFile, job);
writeToFlowFile(session, flowFile, describeTextTranslationJobResult); final Relationship transferRelationship;
session.penalize(flowFile); String failureReason = null;
session.transfer(flowFile, REL_RUNNING); switch (status) {
} else if (status == JobStatus.COMPLETED) { case IN_PROGRESS:
session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, describeTextTranslationJobResult.getTextTranslationJobProperties().getOutputDataConfig().getS3Uri()); case SUBMITTED:
writeToFlowFile(session, flowFile, describeTextTranslationJobResult); case STOP_REQUESTED:
session.transfer(flowFile, REL_SUCCESS); flowFile = session.penalize(flowFile);
} else if (status == JobStatus.FAILED || status == JobStatus.COMPLETED_WITH_ERROR) { transferRelationship = REL_RUNNING;
writeToFlowFile(session, flowFile, describeTextTranslationJobResult); break;
session.transfer(flowFile, REL_FAILURE); case COMPLETED:
flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.textTranslationJobProperties().outputDataConfig().s3Uri());
transferRelationship = REL_SUCCESS;
break;
case FAILED:
case COMPLETED_WITH_ERROR:
failureReason = job.textTranslationJobProperties().message();
transferRelationship = REL_FAILURE;
break;
case STOPPED:
failureReason = String.format("Job [%s] is stopped", job.textTranslationJobProperties().jobId());
transferRelationship = REL_FAILURE;
break;
default:
failureReason = "Unknown Job Status";
transferRelationship = REL_FAILURE;
} }
} catch (ThrottlingException e) { if (failureReason != null) {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
}
session.transfer(flowFile, transferRelationship);
} catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e); getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED); session.transfer(flowFile, REL_THROTTLED);
} catch (Exception e) { } catch (final Exception e) {
getLogger().warn("Failed to get Polly Job status", e); getLogger().warn("Failed to get Translate Job status", e);
session.transfer(flowFile, REL_FAILURE); session.transfer(flowFile, REL_FAILURE);
} }
} }
private DescribeTextTranslationJobResult getStatusString(ProcessContext context, String awsTaskId) { private DescribeTextTranslationJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
DescribeTextTranslationJobRequest request = new DescribeTextTranslationJobRequest().withJobId(awsTaskId); final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
DescribeTextTranslationJobResult translationJobsResult = getClient(context).describeTextTranslationJob(request); final DescribeTextTranslationJobRequest request = DescribeTextTranslationJobRequest.builder().jobId(taskId).build();
return translationJobsResult; return getClient(context).describeTextTranslationJob(request);
} }
} }

View File

@ -17,47 +17,44 @@
package org.apache.nifi.processors.aws.ml.translate; package org.apache.nifi.processors.aws.ml.translate;
import com.amazonaws.ClientConfiguration; import org.apache.nifi.annotation.behavior.WritesAttribute;
import com.amazonaws.auth.AWSCredentialsProvider; import org.apache.nifi.annotation.behavior.WritesAttributes;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.StartTextTranslationJobRequest;
import com.amazonaws.services.translate.model.StartTextTranslationJobResult;
import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso; import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags; import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter; import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.TranslateClientBuilder;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"}) @Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Trigger a AWS Translate job. It should be followed by GetAwsTranslateJobStatus processor in order to monitor job status.") @CapabilityDescription("Trigger a AWS Translate job. It should be followed by GetAwsTranslateJobStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranslateJobStatus")
})
@SeeAlso({GetAwsTranslateJobStatus.class}) @SeeAlso({GetAwsTranslateJobStatus.class})
public class StartAwsTranslateJob extends AwsMachineLearningJobStarter<AmazonTranslateClient, StartTextTranslationJobRequest, StartTextTranslationJobResult> { public class StartAwsTranslateJob extends AbstractAwsMachineLearningJobStarter<
StartTextTranslationJobRequest, StartTextTranslationJobRequest.Builder, StartTextTranslationJobResponse, TranslateClient, TranslateClientBuilder> {
@Override @Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) { return TranslateClient.builder();
return (AmazonTranslateClient) AmazonTranslateClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
} }
@Override @Override
protected StartTextTranslationJobResult sendRequest(StartTextTranslationJobRequest request, ProcessContext context, FlowFile flowFile) { protected StartTextTranslationJobResponse sendRequest(final StartTextTranslationJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTextTranslationJob(request); return getClient(context).startTextTranslationJob(request);
} }
@Override @Override
protected Class<StartTextTranslationJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) { protected Class<? extends StartTextTranslationJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartTextTranslationJobRequest.class; return StartTextTranslationJobRequest.serializableBuilderClass();
} }
protected String getAwsTaskId(ProcessContext context, StartTextTranslationJobResult startTextTranslationJobResult, FlowFile flowFile) { protected String getAwsTaskId(final ProcessContext context, final StartTextTranslationJobResponse startTextTranslationJobResponse, final FlowFile flowFile) {
return startTextTranslationJobResult.getJobId(); return startTextTranslationJobResponse.jobId();
} }
} }

View File

@ -17,18 +17,10 @@
package org.apache.nifi.processors.aws.ml.polly; package org.apache.nifi.processors.aws.ml.polly;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
import com.amazonaws.services.polly.model.SynthesisTask;
import com.amazonaws.services.polly.model.TaskStatus;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners; import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -38,16 +30,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor; import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.SynthesisTask;
import software.amazon.awssdk.services.polly.model.TaskStatus;
import java.util.Collections; import java.util.Collections;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_ORIGINAL; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -57,73 +53,84 @@ public class GetAwsPollyStatusTest {
private static final String PLACEHOLDER_CONTENT = "content"; private static final String PLACEHOLDER_CONTENT = "content";
private TestRunner runner; private TestRunner runner;
@Mock @Mock
private AmazonPollyClient mockPollyClient; private PollyClient mockPollyClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider; private GetAwsPollyJobStatus processor;
@Captor @Captor
private ArgumentCaptor<GetSpeechSynthesisTaskRequest> requestCaptor; private ArgumentCaptor<GetSpeechSynthesisTaskRequest> requestCaptor;
private TestRunner createRunner(final GetAwsPollyJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach @BeforeEach
public void setUp() throws InitializationException { public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider"); processor = new GetAwsPollyJobStatus() {
final GetAwsPollyJobStatus mockGetAwsPollyStatus = new GetAwsPollyJobStatus() {
@Override @Override
protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, public PollyClient getClient(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return mockPollyClient; return mockPollyClient;
} }
}; };
runner = TestRunners.newTestRunner(mockGetAwsPollyStatus); runner = createRunner(processor);
runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
} }
@Test @Test
public void testPollyTaskInProgress() { public void testPollyTaskInProgress() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult(); GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID) .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).taskStatus(TaskStatus.IN_PROGRESS).build())
.withTaskStatus(TaskStatus.InProgress); .build();
taskResult.setSynthesisTask(task); when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run(); runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING); runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId()); assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
} }
@Test @Test
public void testPollyTaskCompleted() { public void testPollyTaskCompleted() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult(); final String uri = "https://s3.us-west2.amazonaws.com/bucket/object";
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID) final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.withTaskStatus(TaskStatus.Completed) .synthesisTask(SynthesisTask.builder()
.withOutputUri("outputLocationPath"); .taskId(TEST_TASK_ID)
taskResult.setSynthesisTask(task); .taskStatus(TaskStatus.COMPLETED)
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult); .outputUri(uri).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run(); runner.run();
runner.assertTransferCount(REL_SUCCESS, 1); runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1); runner.assertTransferCount(REL_ORIGINAL, 1);
runner.assertAllFlowFilesContainAttribute(REL_SUCCESS, AWS_TASK_OUTPUT_LOCATION); runner.assertAllFlowFilesContainAttribute(REL_SUCCESS, AWS_TASK_OUTPUT_LOCATION);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId()); assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
}
final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
assertEquals(uri, flowFile.getAttribute(GetAwsPollyJobStatus.AWS_TASK_OUTPUT_LOCATION));
assertEquals("bucket", flowFile.getAttribute("PollyS3OutputBucket"));
assertEquals("object", flowFile.getAttribute("filename"));
}
@Test @Test
public void testPollyTaskFailed() { public void testPollyTaskFailed() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult(); final String failureReason = "reasonOfFailure";
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID) final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.withTaskStatus(TaskStatus.Failed) .synthesisTask(SynthesisTask.builder()
.withTaskStatusReason("reasonOfFailure"); .taskId(TEST_TASK_ID)
taskResult.setSynthesisTask(task); .taskStatus(TaskStatus.FAILED)
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult); .taskStatusReason(failureReason).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run(); runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE); runner.assertAllFlowFilesTransferred(REL_FAILURE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId()); assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
assertEquals(failureReason, flowFile.getAttribute(GetAwsPollyJobStatus.FAILURE_REASON_ATTRIBUTE));
} }
} }

View File

@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.polly;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.model.Engine;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.SynthesisTask;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StartAwsPollyJobTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
private PollyClient mockPollyClient;
private StartAwsPollyJob processor;
private ObjectMapper objectMapper = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.build();
@Captor
private ArgumentCaptor<StartSpeechSynthesisTaskRequest> requestCaptor;
private TestRunner createRunner(final StartAwsPollyJob processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach
public void setUp() throws InitializationException {
processor = new StartAwsPollyJob() {
@Override
public PollyClient getClient(ProcessContext context) {
return mockPollyClient;
}
};
runner = createRunner(processor);
}
@Test
public void testSuccessfulFlowfileContent() throws JsonProcessingException {
final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
.engine(Engine.NEURAL)
.text("Text")
.build();
final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build())
.build();
when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
final String requestJson = serialize(request);
runner.enqueue(requestJson);
runner.run();
runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData);
assertEquals("Text", requestCaptor.getValue().text());
assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId());
}
@Test
public void testSuccessfulAttribute() throws JsonProcessingException {
final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
.engine(Engine.NEURAL)
.text("Text")
.build();
final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build())
.build();
when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
final String requestJson = serialize(request);
runner.setProperty(StartAwsPollyJob.JSON_PAYLOAD, "${json.payload}");
final Map<String, String> attributes = new HashMap<>();
attributes.put("json.payload", requestJson);
runner.enqueue("", attributes);
runner.run();
runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData);
assertEquals("Text", requestCaptor.getValue().text());
assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId());
}
@Test
public void testInvalidJson() {
final String requestJson = "invalid";
runner.enqueue(requestJson);
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
@Test
public void testServiceFailure() throws JsonProcessingException {
final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
.engine(Engine.NEURAL)
.text("Text")
.build();
when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
final String requestJson = serialize(request);
runner.enqueue(requestJson);
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
private StartSpeechSynthesisTaskResponse deserialize(final String responseData) throws JsonProcessingException {
return objectMapper.readValue(responseData, StartSpeechSynthesisTaskResponse.serializableBuilderClass()).build();
}
// Serialized through its builder form, matching deserialize() which reads the serializable builder
private String serialize(final StartSpeechSynthesisTaskRequest request) throws JsonProcessingException {
    final Object builderForm = request.toBuilder();
    return objectMapper.writeValueAsString(builderForm);
}
}

View File

@ -17,17 +17,10 @@
package org.apache.nifi.processors.aws.ml.textract; package org.apache.nifi.processors.aws.ml.textract;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.GetDocumentAnalysisResult;
import com.amazonaws.services.textract.model.JobStatus;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners; import org.apache.nifi.util.TestRunners;
@ -38,13 +31,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor; import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisResponse;
import software.amazon.awssdk.services.textract.model.JobStatus;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_THROTTLED;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -53,64 +53,144 @@ public class GetAwsTextractJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId"; private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner; private TestRunner runner;
@Mock @Mock
private AmazonTextractClient mockTextractClient; private TextractClient mockTextractClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider; private GetAwsTextractJobStatus processor;
@Captor @Captor
private ArgumentCaptor<GetDocumentAnalysisRequest> requestCaptor; private ArgumentCaptor<GetDocumentAnalysisRequest> documentAnalysisCaptor;
@Captor
private ArgumentCaptor<GetExpenseAnalysisRequest> expenseAnalysisRequestCaptor;
@Captor
private ArgumentCaptor<GetDocumentTextDetectionRequest> documentTextDetectionCaptor;
private TestRunner createRunner(final GetAwsTextractJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach @BeforeEach
public void setUp() throws InitializationException { public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider"); processor = new GetAwsTextractJobStatus() {
final GetAwsTextractJobStatus awsTextractJobStatusGetter = new GetAwsTextractJobStatus() {
@Override @Override
protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, public TextractClient getClient(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return mockTextractClient; return mockTextractClient;
} }
}; };
runner = TestRunners.newTestRunner(awsTextractJobStatusGetter); runner = createRunner(processor);
runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
} }
@Test @Test
public void testTextractDocAnalysisTaskInProgress() { public void testTextractDocAnalysisTaskInProgress() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult() testTextractDocAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
.withJobStatus(JobStatus.IN_PROGRESS);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
} }
@Test @Test
public void testTextractDocAnalysisTaskComplete() { public void testTextractDocAnalysisTaskComplete() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult() testTextractDocAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
.withJobStatus(JobStatus.SUCCEEDED);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
runner.run();
runner.assertAllFlowFilesTransferred(REL_SUCCESS);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
} }
@Test @Test
public void testTextractDocAnalysisTaskFailed() { public void testTextractDocAnalysisTaskFailed() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult() testTextractDocAnalysis(JobStatus.FAILED, REL_FAILURE);
.withJobStatus(JobStatus.FAILED); }
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID, @Test
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.type)); public void testTextractDocAnalysisTaskPartialSuccess() {
testTextractDocAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}
@Test
public void testTextractDocAnalysisTaskUnkownStatus() {
testTextractDocAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}
private void testTextractDocAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
final GetDocumentAnalysisResponse response = GetDocumentAnalysisResponse.builder()
.jobStatus(jobStatus).build();
when(mockTextractClient.getDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
runner.enqueue("content", ImmutableMap.of(
TASK_ID.getName(), TEST_TASK_ID,
StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE, TextractType.DOCUMENT_ANALYSIS.getType()));
runner.run(); runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE); runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); assertEquals(TEST_TASK_ID, documentAnalysisCaptor.getValue().jobId());
}
@Test
public void testTextractExpenseAnalysisTaskInProgress() {
testTextractExpenseAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
}
@Test
public void testTextractExpenseAnalysisTaskComplete() {
testTextractExpenseAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
}
@Test
public void testTextractExpenseAnalysisTaskFailed() {
testTextractExpenseAnalysis(JobStatus.FAILED, REL_FAILURE);
}
@Test
public void testTextractExpenseAnalysisTaskPartialSuccess() {
testTextractExpenseAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}
@Test
public void testTextractExpenseAnalysisTaskUnkownStatus() {
testTextractExpenseAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}
private void testTextractExpenseAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.EXPENSE_ANALYSIS.getType());
final GetExpenseAnalysisResponse response = GetExpenseAnalysisResponse.builder()
.jobStatus(jobStatus).build();
when(mockTextractClient.getExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, expenseAnalysisRequestCaptor.getValue().jobId());
}
@Test
public void testTextractDocumentTextDetectionTaskInProgress() {
testTextractDocumentTextDetection(JobStatus.IN_PROGRESS, REL_RUNNING);
}
@Test
public void testTextractDocumentTextDetectionTaskComplete() {
testTextractDocumentTextDetection(JobStatus.SUCCEEDED, REL_SUCCESS);
}
@Test
public void testTextractDocumentTextDetectionTaskFailed() {
testTextractDocumentTextDetection(JobStatus.FAILED, REL_FAILURE);
}
@Test
public void testTextractDocumentTextDetectionTaskPartialSuccess() {
testTextractDocumentTextDetection(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}
@Test
public void testTextractDocumentTextDetectionTaskUnkownStatus() {
testTextractDocumentTextDetection(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}
private void testTextractDocumentTextDetection(final JobStatus jobStatus, final Relationship expectedRelationship) {
runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.DOCUMENT_TEXT_DETECTION.getType());
final GetDocumentTextDetectionResponse response = GetDocumentTextDetectionResponse.builder()
.jobStatus(jobStatus).build();
when(mockTextractClient.getDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, documentTextDetectionCaptor.getValue().jobId());
} }
} }

View File

@ -0,0 +1,342 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.textract;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_TEXT_DETECTION;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.EXPENSE_ANALYSIS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
/**
 * Tests for {@code StartAwsTextractJob}, covering the three Textract job types
 * (document analysis, expense analysis, document text detection) with the request
 * supplied either as FlowFile content or via the {@code JSON_PAYLOAD} property,
 * plus invalid-JSON and service-failure routing.
 */
@ExtendWith(MockitoExtension.class)
public class StartAwsTextractJobStatusTest {
    private static final String TEST_TASK_ID = "testTaskId";

    // ObjectMapper is thread-safe and expensive to construct: build it once and share.
    // Case-insensitive properties are required to map the SDK builder field names.
    private static final ObjectMapper OBJECT_MAPPER = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();

    private TestRunner runner;

    @Mock
    private TextractClient mockTextractClient;

    private StartAwsTextractJob processor;

    @Captor
    private ArgumentCaptor<StartDocumentAnalysisRequest> documentAnalysisCaptor;

    @Captor
    private ArgumentCaptor<StartExpenseAnalysisRequest> expenseAnalysisRequestCaptor;

    @Captor
    private ArgumentCaptor<StartDocumentTextDetectionRequest> documentTextDetectionCaptor;

    /** Creates a TestRunner with static access-key credentials configured. */
    private TestRunner createRunner(final StartAwsTextractJob processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        // Override client creation so no real AWS connection is attempted
        processor = new StartAwsTextractJob() {
            @Override
            public TextractClient getClient(final ProcessContext context) {
                return mockTextractClient;
            }
        };
        runner = createRunner(processor);
    }

    @Test
    public void testSuccessfulDocumentAnalysisFlowfileContent() throws JsonProcessingException {
        final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
                .jobTag("Tag")
                .build();
        final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        // Success FlowFile content is the serialized service response
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData);
        assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testSuccessfulDocumentAnalysisAttribute() throws JsonProcessingException {
        final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
                .jobTag("Tag")
                .build();
        final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
        // Request JSON is supplied through an attribute instead of the FlowFile content
        runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData);
        assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testInvalidDocumentAnalysisJson() {
        // Set the type explicitly for consistency with the expense/text-detection variants
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
        final String requestJson = "invalid";
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testDocumentAnalysisServiceFailure() throws JsonProcessingException {
        final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
                .jobTag("Tag")
                .build();
        when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
        final String requestJson = serialize(request);
        // Set the type explicitly for consistency with the expense/text-detection variants
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testSuccessfulExpenseAnalysisFlowfileContent() throws JsonProcessingException {
        final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
                .jobTag("Tag")
                .build();
        final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData);
        assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testSuccessfulExpenseAnalysisAttribute() throws JsonProcessingException {
        final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
                .jobTag("Tag")
                .build();
        final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
        runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData);
        assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testInvalidExpenseAnalysisJson() {
        final String requestJson = "invalid";
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testExpenseAnalysisServiceFailure() throws JsonProcessingException {
        final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
                .jobTag("Tag")
                .build();
        when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testSuccessfulDocumentTextDetectionFlowfileContent() throws JsonProcessingException {
        final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
                .jobTag("Tag")
                .build();
        final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData);
        assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testSuccessfulDocumentTextDetectionAttribute() throws JsonProcessingException {
        final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
                .jobTag("Tag")
                .build();
        final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
        runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData);
        assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testInvalidDocumentTextDetectionJson() {
        final String requestJson = "invalid";
        runner.enqueue(requestJson);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testDocumentTextDetectionServiceFailure() throws JsonProcessingException {
        final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
                .jobTag("Tag")
                .build();
        when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
        final String requestJson = serialize(request);
        runner.enqueue(requestJson);
        runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    // SDK v2 response objects are read back through their serializable builder classes
    private StartDocumentTextDetectionResponse deserializeDTDRequest(final String responseData) throws JsonProcessingException {
        return OBJECT_MAPPER.readValue(responseData, StartDocumentTextDetectionResponse.serializableBuilderClass()).build();
    }

    private StartDocumentAnalysisResponse deserializeDARequest(final String responseData) throws JsonProcessingException {
        return OBJECT_MAPPER.readValue(responseData, StartDocumentAnalysisResponse.serializableBuilderClass()).build();
    }

    private StartExpenseAnalysisResponse deserializeEARequest(final String responseData) throws JsonProcessingException {
        return OBJECT_MAPPER.readValue(responseData, StartExpenseAnalysisResponse.serializableBuilderClass()).build();
    }

    // Requests are serialized through their builder form, matching the deserialize helpers
    private String serialize(final StartDocumentAnalysisRequest request) throws JsonProcessingException {
        return OBJECT_MAPPER.writeValueAsString(request.toBuilder());
    }

    private String serialize(final StartExpenseAnalysisRequest request) throws JsonProcessingException {
        return OBJECT_MAPPER.writeValueAsString(request.toBuilder());
    }

    private String serialize(final StartDocumentTextDetectionRequest request) throws JsonProcessingException {
        return OBJECT_MAPPER.writeValueAsString(request.toBuilder());
    }
}

View File

@ -17,19 +17,11 @@
package org.apache.nifi.processors.aws.ml.transcribe; package org.apache.nifi.processors.aws.ml.transcribe;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
import com.amazonaws.services.transcribe.model.Transcript;
import com.amazonaws.services.transcribe.model.TranscriptionJob;
import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners; import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -39,93 +31,118 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor; import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.Transcript;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;
import java.util.Collections; import java.util.Collections;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
public class GetAwsTranscribeJobStatusTest { public class GetAwsTranscribeJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId"; private static final String TEST_TASK_ID = "testJobId";
private static final String AWS_CREDENTIAL_PROVIDER_NAME = "awsCredetialProvider";
private static final String OUTPUT_LOCATION_PATH = "outputLocationPath"; private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
private static final String REASON_OF_FAILURE = "reasonOfFailure"; private static final String REASON_OF_FAILURE = "reasonOfFailure";
private static final String CONTENT_STRING = "content"; private static final String CONTENT_STRING = "content";
private TestRunner runner; private TestRunner runner;
@Mock @Mock
private AmazonTranscribeClient mockTranscribeClient; private TranscribeClient mockTranscribeClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider; private GetAwsTranscribeJobStatus processor;
@Captor @Captor
private ArgumentCaptor<GetTranscriptionJobRequest> requestCaptor; private ArgumentCaptor<GetTranscriptionJobRequest> requestCaptor;
private TestRunner createRunner(final GetAwsTranscribeJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach @BeforeEach
public void setUp() throws InitializationException { public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIAL_PROVIDER_NAME); processor = new GetAwsTranscribeJobStatus() {
final GetAwsTranscribeJobStatus mockPollyFetcher = new GetAwsTranscribeJobStatus() {
@Override @Override
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, public TranscribeClient getClient(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return mockTranscribeClient; return mockTranscribeClient;
} }
}; };
runner = TestRunners.newTestRunner(mockPollyFetcher); runner = createRunner(processor);
runner.addControllerService(AWS_CREDENTIAL_PROVIDER_NAME, mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIAL_PROVIDER_NAME);
} }
@Test @Test
public void testTranscribeTaskInProgress() { public void testTranscribeJobInProgress() {
TranscriptionJob task = new TranscriptionJob() final TranscriptionJob job = TranscriptionJob.builder()
.withTranscriptionJobName(TEST_TASK_ID) .transcriptionJobName(TEST_TASK_ID)
.withTranscriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS); .transcriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS)
GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task); .build();
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult); testTranscribeJob(job, REL_RUNNING);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
} }
@Test @Test
public void testTranscribeTaskCompleted() { public void testTranscribeJobQueued() {
TranscriptionJob task = new TranscriptionJob() final TranscriptionJob job = TranscriptionJob.builder()
.withTranscriptionJobName(TEST_TASK_ID) .transcriptionJobName(TEST_TASK_ID)
.withTranscript(new Transcript().withTranscriptFileUri(OUTPUT_LOCATION_PATH)) .transcriptionJobStatus(TranscriptionJobStatus.QUEUED)
.withTranscriptionJobStatus(TranscriptionJobStatus.COMPLETED); .build();
GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task); testTranscribeJob(job, REL_RUNNING);
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_SUCCESS);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
} }
@Test
public void testTranscribeJobCompleted() {
final TranscriptionJob job = TranscriptionJob.builder()
.transcriptionJobName(TEST_TASK_ID)
.transcript(Transcript.builder().transcriptFileUri(OUTPUT_LOCATION_PATH).build())
.transcriptionJobStatus(TranscriptionJobStatus.COMPLETED)
.build();
testTranscribeJob(job, REL_SUCCESS);
runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
}
@Test @Test
public void testPollyTaskFailed() { public void testTranscribeJobFailed() {
TranscriptionJob task = new TranscriptionJob() final TranscriptionJob job = TranscriptionJob.builder()
.withTranscriptionJobName(TEST_TASK_ID) .transcriptionJobName(TEST_TASK_ID)
.withFailureReason(REASON_OF_FAILURE) .failureReason(REASON_OF_FAILURE)
.withTranscriptionJobStatus(TranscriptionJobStatus.FAILED); .transcriptionJobStatus(TranscriptionJobStatus.FAILED)
GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task); .build();
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult); testTranscribeJob(job, REL_FAILURE);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE);
runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE); runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName()); final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
assertEquals(REASON_OF_FAILURE, flowFile.getAttribute(FAILURE_REASON_ATTRIBUTE));
}
@Test
public void testTranscribeJobUnrecognized() {
final TranscriptionJob job = TranscriptionJob.builder()
.transcriptionJobName(TEST_TASK_ID)
.failureReason(REASON_OF_FAILURE)
.transcriptionJobStatus(TranscriptionJobStatus.UNKNOWN_TO_SDK_VERSION)
.build();
testTranscribeJob(job, REL_FAILURE);
runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
}
private void testTranscribeJob(final TranscriptionJob job, final Relationship expectedRelationship) {
final GetTranscriptionJobResponse response = GetTranscriptionJobResponse.builder().transcriptionJob(job).build();
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(response);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().transcriptionJobName());
} }
} }

View File

@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.transcribe;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StartAwsTranscribeJobTest {
    private static final String TEST_TASK_ID = "testTaskId";
    private static final String JOB_NAME = "Job";

    // ObjectMapper is thread-safe and expensive to construct, so share one instance across tests.
    // Case-insensitive properties are required to round-trip the SDK builder's field names.
    private static final ObjectMapper OBJECT_MAPPER = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();

    private TestRunner runner;

    @Mock
    private TranscribeClient mockTranscribeClient;

    private StartAwsTranscribeJob processor;

    @Captor
    private ArgumentCaptor<StartTranscriptionJobRequest> requestCaptor;

    /**
     * Creates a TestRunner for the given processor with static access-key credentials
     * so the processor validates without a credentials controller service.
     */
    private TestRunner createRunner(final StartAwsTranscribeJob processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        // Override client creation so no real AWS connection is attempted
        processor = new StartAwsTranscribeJob() {
            @Override
            public TranscribeClient getClient(final ProcessContext context) {
                return mockTranscribeClient;
            }
        };
        runner = createRunner(processor);
    }

    @Test
    public void testSuccessfulFlowfileContent() throws JsonProcessingException {
        final String requestJson = serialize(buildRequest());
        stubSuccessfulResponse();

        // Request JSON supplied as FlowFile content (default JSON_PAYLOAD behavior)
        runner.enqueue(requestJson);
        runner.run();

        assertSuccessfulJobStart();
    }

    @Test
    public void testSuccessfulAttribute() throws JsonProcessingException {
        final String requestJson = serialize(buildRequest());
        stubSuccessfulResponse();

        // Request JSON supplied via Expression Language from a FlowFile attribute
        runner.setProperty(StartAwsTranscribeJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();

        assertSuccessfulJobStart();
    }

    @Test
    public void testInvalidJson() {
        // Content that cannot be deserialized into a StartTranscriptionJobRequest routes to failure
        final String requestJson = "invalid";
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testServiceFailure() throws JsonProcessingException {
        // A service-side exception from the Transcribe client routes the FlowFile to failure
        when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture()))
                .thenThrow(AwsServiceException.builder().message("message").build());
        final String requestJson = serialize(buildRequest());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    /** Builds the canonical request used by the success-path tests. */
    private StartTranscriptionJobRequest buildRequest() {
        return StartTranscriptionJobRequest.builder()
                .transcriptionJobName(JOB_NAME)
                .build();
    }

    /** Stubs the mocked client to return a response naming {@link #TEST_TASK_ID}. */
    private void stubSuccessfulResponse() {
        final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder()
                .transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build())
                .build();
        when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response);
    }

    /**
     * Shared success assertions: one FlowFile on success and one on original,
     * the captured request carries the submitted job name, and the success
     * FlowFile content deserializes back to the stubbed response.
     */
    private void assertSuccessfulJobStart() throws JsonProcessingException {
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTranscriptionJobResponse parsedResponse = deserialize(responseData);
        assertEquals(JOB_NAME, requestCaptor.getValue().transcriptionJobName());
        assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName());
    }

    /** Deserializes processor output back into an SDK response via its serializable builder. */
    private StartTranscriptionJobResponse deserialize(final String responseData) throws JsonProcessingException {
        return OBJECT_MAPPER.readValue(responseData, StartTranscriptionJobResponse.serializableBuilderClass()).build();
    }

    /** Serializes the request via its builder, matching the processor's expected JSON shape. */
    private String serialize(final StartTranscriptionJobRequest request) throws JsonProcessingException {
        return OBJECT_MAPPER.writeValueAsString(request.toBuilder());
    }
}

View File

@ -17,19 +17,11 @@
package org.apache.nifi.processors.aws.ml.translate; package org.apache.nifi.processors.aws.ml.translate;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
import com.amazonaws.services.translate.model.JobStatus;
import com.amazonaws.services.translate.model.OutputDataConfig;
import com.amazonaws.services.translate.model.TextTranslationJobProperties;
import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService; import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException; import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners; import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -39,91 +31,132 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor; import org.mockito.Captor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
import software.amazon.awssdk.services.translate.model.JobStatus;
import software.amazon.awssdk.services.translate.model.OutputDataConfig;
import software.amazon.awssdk.services.translate.model.TextTranslationJobProperties;
import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS; import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
public class GetAwsTranslateJobStatusTest { public class GetAwsTranslateJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId"; private static final String TEST_TASK_ID = "testJobId";
private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
private static final String REASON_OF_FAILURE = "reasonOfFailure";
private static final String CONTENT_STRING = "content"; private static final String CONTENT_STRING = "content";
private static final String AWS_CREDENTIALS_PROVIDER_NAME = "awsCredetialProvider";
private static final String OUTPUT_LOCATION_PATH = "outputLocation";
private TestRunner runner; private TestRunner runner;
@Mock @Mock
private AmazonTranslateClient mockTranslateClient; private TranslateClient mockTranslateClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider; private GetAwsTranslateJobStatus processor;
@Captor @Captor
private ArgumentCaptor<DescribeTextTranslationJobRequest> requestCaptor; private ArgumentCaptor<DescribeTextTranslationJobRequest> requestCaptor;
private TestRunner createRunner(final GetAwsTranslateJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach @BeforeEach
public void setUp() throws InitializationException { public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIALS_PROVIDER_NAME); processor = new GetAwsTranslateJobStatus() {
final GetAwsTranslateJobStatus mockPollyFetcher = new GetAwsTranslateJobStatus() {
@Override @Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config, public TranslateClient getClient(final ProcessContext context) {
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return mockTranslateClient; return mockTranslateClient;
} }
}; };
runner = TestRunners.newTestRunner(mockPollyFetcher); runner = createRunner(processor);
runner.addControllerService(AWS_CREDENTIALS_PROVIDER_NAME, mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIALS_PROVIDER_NAME);
} }
@Test @Test
public void testTranscribeTaskInProgress() { public void testTranslateJobInProgress() {
TextTranslationJobProperties task = new TextTranslationJobProperties() final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.withJobId(TEST_TASK_ID) .jobId(TEST_TASK_ID)
.withJobStatus(JobStatus.IN_PROGRESS); .jobStatus(JobStatus.IN_PROGRESS)
DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task); .build();
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult); testTranslateJob(job, REL_RUNNING);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
} }
@Test @Test
public void testTranscribeTaskCompleted() { public void testTranslateSubmitted() {
TextTranslationJobProperties task = new TextTranslationJobProperties() final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.withJobId(TEST_TASK_ID) .jobId(TEST_TASK_ID)
.withOutputDataConfig(new OutputDataConfig().withS3Uri(OUTPUT_LOCATION_PATH)) .jobStatus(JobStatus.SUBMITTED)
.withJobStatus(JobStatus.COMPLETED); .build();
DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task); testTranslateJob(job, REL_RUNNING);
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult); }
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_SUCCESS); @Test
public void testTranslateStopRequested() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobId(TEST_TASK_ID)
.jobStatus(JobStatus.STOP_REQUESTED)
.build();
testTranslateJob(job, REL_RUNNING);
}
@Test
public void testTranslateJobCompleted() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobStatus(TEST_TASK_ID)
.outputDataConfig(OutputDataConfig.builder().s3Uri(OUTPUT_LOCATION_PATH).build())
.submittedTime(Instant.now())
.jobStatus(JobStatus.COMPLETED)
.build();
testTranslateJob(job, REL_SUCCESS);
runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION); runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
} }
@Test @Test
public void testTranscribeTaskFailed() { public void testTranslateJobFailed() {
TextTranslationJobProperties task = new TextTranslationJobProperties() final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.withJobId(TEST_TASK_ID) .jobStatus(TEST_TASK_ID)
.withJobStatus(JobStatus.FAILED); .jobStatus(JobStatus.FAILED)
DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task); .build();
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult); testTranslateJob(job, REL_FAILURE);
}
@Test
public void testTranslateJobStopped() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobStatus(TEST_TASK_ID)
.jobStatus(JobStatus.STOPPED)
.build();
testTranslateJob(job, REL_FAILURE);
}
@Test
public void testTranslateJobUnrecognized() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobStatus(TEST_TASK_ID)
.jobStatus(JobStatus.UNKNOWN_TO_SDK_VERSION)
.build();
testTranslateJob(job, REL_FAILURE);
}
private void testTranslateJob(final TextTranslationJobProperties job, final Relationship expectedRelationship) {
final DescribeTextTranslationJobResponse response = DescribeTextTranslationJobResponse.builder().textTranslationJobProperties(job).build();
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(response);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID)); runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run(); runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE); runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId()); assertEquals(TEST_TASK_ID, requestCaptor.getValue().jobId());
} }
} }

View File

@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.translate;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StartAwsTranslateJobTest {
    private static final String TEST_TASK_ID = "testTaskId";
    private static final String TERMINOLOGY_NAME = "Name";

    // ObjectMapper is thread-safe and expensive to construct, so share one instance across tests.
    // Case-insensitive properties are required to round-trip the SDK builder's field names.
    private static final ObjectMapper OBJECT_MAPPER = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();

    private TestRunner runner;

    @Mock
    private TranslateClient mockTranslateClient;

    private StartAwsTranslateJob processor;

    @Captor
    private ArgumentCaptor<StartTextTranslationJobRequest> requestCaptor;

    /**
     * Creates a TestRunner for the given processor with static access-key credentials
     * so the processor validates without a credentials controller service.
     */
    private TestRunner createRunner(final StartAwsTranslateJob processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        // Override client creation so no real AWS connection is attempted
        processor = new StartAwsTranslateJob() {
            @Override
            public TranslateClient getClient(final ProcessContext context) {
                return mockTranslateClient;
            }
        };
        runner = createRunner(processor);
    }

    @Test
    public void testSuccessfulFlowfileContent() throws JsonProcessingException {
        final String requestJson = serialize(buildRequest());
        stubSuccessfulResponse();

        // Request JSON supplied as FlowFile content (default JSON_PAYLOAD behavior)
        runner.enqueue(requestJson);
        runner.run();

        assertSuccessfulJobStart();
    }

    @Test
    public void testSuccessfulAttribute() throws JsonProcessingException {
        final String requestJson = serialize(buildRequest());
        stubSuccessfulResponse();

        // Request JSON supplied via Expression Language from a FlowFile attribute
        runner.setProperty(StartAwsTranslateJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();

        assertSuccessfulJobStart();
    }

    @Test
    public void testInvalidJson() {
        // Content that cannot be deserialized into a StartTextTranslationJobRequest routes to failure
        final String requestJson = "invalid";
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testServiceFailure() throws JsonProcessingException {
        // A service-side exception from the Translate client routes the FlowFile to failure
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture()))
                .thenThrow(AwsServiceException.builder().message("message").build());
        final String requestJson = serialize(buildRequest());
        runner.enqueue(requestJson);
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    /** Builds the canonical request used by the success-path tests. */
    private StartTextTranslationJobRequest buildRequest() {
        return StartTextTranslationJobRequest.builder()
                .terminologyNames(TERMINOLOGY_NAME)
                .build();
    }

    /** Stubs the mocked client to return a response carrying {@link #TEST_TASK_ID}. */
    private void stubSuccessfulResponse() {
        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);
    }

    /**
     * Shared success assertions: one FlowFile on success and one on original,
     * the captured request carries the submitted terminology name, and the
     * success FlowFile content deserializes back to the stubbed response.
     */
    private void assertSuccessfulJobStart() throws JsonProcessingException {
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);
        assertEquals(Collections.singletonList(TERMINOLOGY_NAME), requestCaptor.getValue().terminologyNames());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    /** Deserializes processor output back into an SDK response via its serializable builder. */
    private StartTextTranslationJobResponse deserialize(final String responseData) throws JsonProcessingException {
        return OBJECT_MAPPER.readValue(responseData, StartTextTranslationJobResponse.serializableBuilderClass()).build();
    }

    /** Serializes the request via its builder, matching the processor's expected JSON shape. */
    private String serialize(final StartTextTranslationJobRequest request) throws JsonProcessingException {
        return OBJECT_MAPPER.writeValueAsString(request.toBuilder());
    }
}

View File

@ -66,7 +66,6 @@ public abstract class AbstractSQSIT {
queueUrl = response.queueUrl(); queueUrl = response.queueUrl();
} }
@AfterAll @AfterAll
public static void shutdown() { public static void shutdown() {
client.close(); client.close();