NIFI-12263 Upgraded AWS Machine Learning processors to SDK 2

This closes #7953

Signed-off-by: David Handermann <exceptionfactory@apache.org>

This commit is contained in:
parent c706877147
commit 77834c92df
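Reviewer note (not part of the commit): the change follows one pattern throughout: AWS SDK v1 clients and "new Request().withX(...)" setters are replaced by SDK v2 builders with fluent accessors, and client construction moves into the shared AbstractAwsSyncProcessor base via createClientBuilder(). A minimal sketch of the v2 style, assuming the software.amazon.awssdk:polly dependency added below; the region and task ID are placeholders and AWS credentials are resolved from the default provider chain at runtime.

// Editor's sketch of the SDK v2 pattern, not code from this commit.
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;

public class PollyV2Sketch {
    public static void main(String[] args) {
        // v1: AmazonPollyClient.builder().withRegion(...).withCredentials(...).build()
        try (PollyClient client = PollyClient.builder().region(Region.US_EAST_1).build()) {
            // v1: new GetSpeechSynthesisTaskRequest().withTaskId(taskId)
            GetSpeechSynthesisTaskRequest request = GetSpeechSynthesisTaskRequest.builder()
                    .taskId(args[0])
                    .build();
            // v1: result.getSynthesisTask().getTaskStatus(); v2 drops the "get" prefix
            GetSpeechSynthesisTaskResponse response = client.getSpeechSynthesisTask(request);
            System.out.println(response.synthesisTask().taskStatus());
        }
    }
}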
@@ -65,6 +65,22 @@
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>firehose</artifactId>
        </dependency>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>polly</artifactId>
        </dependency>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>textract</artifactId>
        </dependency>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>transcribe</artifactId>
        </dependency>
        <dependency>
            <groupId>software.amazon.awssdk</groupId>
            <artifactId>translate</artifactId>
        </dependency>
        <dependency>
            <groupId>software.amazon.kinesis</groupId>
            <artifactId>amazon-kinesis-client</artifactId>
@@ -130,6 +146,10 @@
            <groupId>com.github.ben-manes.caffeine</groupId>
            <artifactId>caffeine</artifactId>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.datatype</groupId>
            <artifactId>jackson-datatype-jsr310</artifactId>
        </dependency>
        <!-- Version 2 of the AmazonS3EncryptionClient requires bouncy castle -->
        <dependency>
            <groupId>org.bouncycastle</groupId>
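Side note on the jackson-datatype-jsr310 addition (my reading, not stated in the commit): the processors build their ObjectMapper with findAndAddModules(), which discovers the JSR-310 module from the classpath so java.time values used by SDK v2 models (Instant timestamps, for example) serialize cleanly. A small self-contained sketch:

// Editor's sketch; assumes jackson-databind and jackson-datatype-jsr310 on the classpath.
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import java.time.Instant;

public class Jsr310Sketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = JsonMapper.builder()
                .findAndAddModules()   // picks up the JavaTimeModule via ServiceLoader
                .build();
        // Recent Jackson versions refuse java.time types unless such a module is registered.
        System.out.println(mapper.writeValueAsString(Instant.EPOCH));
    }
}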
@@ -17,14 +17,6 @@

package org.apache.nifi.processors.aws.ml;

import com.amazonaws.AmazonWebServiceClient;
import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.AmazonWebServiceResult;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -33,12 +25,18 @@ import org.apache.commons.io.IOUtils;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor;
import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
import software.amazon.awssdk.awscore.AwsRequest;
import software.amazon.awssdk.awscore.AwsResponse;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.SdkClient;

import java.io.IOException;
import java.io.InputStream;
@@ -46,10 +44,15 @@ import java.util.List;
import java.util.Set;

import static org.apache.nifi.flowfile.attributes.CoreAttributes.MIME_TYPE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;

public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceClient, REQUEST extends AmazonWebServiceRequest, RESPONSE extends AmazonWebServiceResult>
        extends AbstractAWSCredentialsProviderProcessor<T> {
public abstract class AbstractAwsMachineLearningJobStarter<
        Q extends AwsRequest,
        B extends AwsRequest.Builder,
        R extends AwsResponse,
        T extends SdkClient,
        U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
        extends AbstractAwsSyncProcessor<T, U> {
    public static final PropertyDescriptor JSON_PAYLOAD = new PropertyDescriptor.Builder()
            .name("json-payload")
            .displayName("JSON Payload")
@@ -62,18 +65,17 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
            new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE)
                    .required(true)
                    .build();
    public static final PropertyDescriptor REGION = new PropertyDescriptor.Builder()
            .displayName("Region")
            .name("aws-region")
            .required(true)
            .allowableValues(getAvailableRegions())
            .defaultValue(createAllowableValue(Regions.DEFAULT_REGION).getValue())
            .build();
    public static final Relationship REL_ORIGINAL = new Relationship.Builder()
            .name("original")
            .description("Upon successful completion, the original FlowFile will be routed to this relationship.")
            .autoTerminateDefault(true)
            .build();

    @Override
    public void migrateProperties(final PropertyConfiguration config) {
        config.renameProperty("aws-region", REGION.getName());
    }

    protected static final List<PropertyDescriptor> PROPERTIES = List.of(
            MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE,
            REGION,
@@ -84,10 +86,9 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli

    private final static ObjectMapper MAPPER = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .findAndAddModules()
            .build();
    private static final Set<Relationship> relationships = Set.of(REL_ORIGINAL,
            REL_SUCCESS,
            REL_FAILURE);
    private static final Set<Relationship> relationships = Set.of(REL_ORIGINAL, REL_SUCCESS, REL_FAILURE);

    @Override
    public Set<Relationship> getRelationships() {
@@ -105,14 +106,14 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
        if (flowFile == null && !context.getProperty(JSON_PAYLOAD).isSet()) {
            return;
        }
        final RESPONSE response;
        final R response;
        FlowFile childFlowFile;
        try {
            response = sendRequest(buildRequest(session, context, flowFile), context, flowFile);
            childFlowFile = writeToFlowFile(session, flowFile, response);
            postProcessFlowFile(context, session, childFlowFile, response);
            childFlowFile = postProcessFlowFile(context, session, childFlowFile, response);
            session.transfer(childFlowFile, REL_SUCCESS);
        } catch (Exception e) {
        } catch (final Exception e) {
            if (flowFile != null) {
                session.transfer(flowFile, REL_FAILURE);
            }
@@ -125,26 +126,21 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli

    }

    protected void postProcessFlowFile(ProcessContext context, ProcessSession session, FlowFile flowFile, RESPONSE response) {
    protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, final FlowFile flowFile, final R response) {
        final String awsTaskId = getAwsTaskId(context, response, flowFile);
        flowFile = session.putAttribute(flowFile, TASK_ID.getName(), awsTaskId);
        flowFile = session.putAttribute(flowFile, MIME_TYPE.key(), "application/json");
        FlowFile processedFlowFile = session.putAttribute(flowFile, TASK_ID.getName(), awsTaskId);
        processedFlowFile = session.putAttribute(processedFlowFile, MIME_TYPE.key(), "application/json");
        getLogger().debug("AWS ML Task [{}] started", awsTaskId);
        return processedFlowFile;
    }

    protected REQUEST buildRequest(ProcessSession session, ProcessContext context, FlowFile flowFile) throws JsonProcessingException {
        return MAPPER.readValue(getPayload(session, context, flowFile), getAwsRequestClass(context, flowFile));
    protected Q buildRequest(final ProcessSession session, final ProcessContext context, final FlowFile flowFile) throws JsonProcessingException {
        return (Q) MAPPER.readValue(getPayload(session, context, flowFile), getAwsRequestBuilderClass(context, flowFile)).build();
    }

    @Override
    protected T createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                             final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
        throw new UnsupportedOperationException();
    }

    protected FlowFile writeToFlowFile(ProcessSession session, FlowFile flowFile, RESPONSE response) {
    protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final R response) {
        FlowFile childFlowFile = flowFile == null ? session.create() : session.create(flowFile);
        childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response));
        childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
        return childFlowFile;
    }

@@ -156,7 +152,7 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
        }
    }

    private String getPayload(ProcessSession session, ProcessContext context, FlowFile flowFile) {
    private String getPayload(final ProcessSession session, final ProcessContext context, final FlowFile flowFile) {
        String payloadPropertyValue = context.getProperty(JSON_PAYLOAD).evaluateAttributeExpressions(flowFile).getValue();
        if (payloadPropertyValue == null) {
            payloadPropertyValue = readFlowFile(session, flowFile);
@@ -164,9 +160,9 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
        return payloadPropertyValue;
    }

    abstract protected RESPONSE sendRequest(REQUEST request, ProcessContext context, FlowFile flowFile) throws JsonProcessingException;
    abstract protected R sendRequest(Q request, ProcessContext context, FlowFile flowFile) throws JsonProcessingException;

    abstract protected Class<? extends REQUEST> getAwsRequestClass(ProcessContext context, FlowFile flowFile);
    abstract protected Class<? extends B> getAwsRequestBuilderClass(ProcessContext context, FlowFile flowFile);

    abstract protected String getAwsTaskId(ProcessContext context, RESPONSE response, FlowFile flowFile);
    abstract protected String getAwsTaskId(ProcessContext context, R response, FlowFile flowFile);
}
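To illustrate the reworked buildRequest contract above: the JSON payload is now bound to the request's serializable builder class and then built. A standalone sketch against the Polly API follows; the JSON field names are assumptions based on Polly's StartSpeechSynthesisTask operation, not values taken from this commit.

// Editor's sketch of payload-to-builder binding; field names are assumed, not from the diff.
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;

public class PayloadBindingSketch {
    public static void main(String[] args) throws Exception {
        // Mirrors the processors' mapper configuration: case-insensitive property matching.
        ObjectMapper mapper = JsonMapper.builder()
                .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
                .findAndAddModules()
                .build();
        String json = "{\"OutputS3BucketName\":\"my-bucket\",\"OutputFormat\":\"mp3\",\"Text\":\"hello\",\"VoiceId\":\"Joanna\"}";
        // getAwsRequestBuilderClass() returns the serializable builder class in the concrete processors.
        StartSpeechSynthesisTaskRequest.Builder builder =
                mapper.readValue(json, StartSpeechSynthesisTaskRequest.serializableBuilderClass());
        System.out.println(builder.build());
    }
}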
@@ -17,33 +17,30 @@

package org.apache.nifi.processors.aws.ml;

import com.amazonaws.AmazonWebServiceClient;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.ResponseMetadata;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.http.SdkHttpMetadata;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor;
import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
import software.amazon.awssdk.awscore.AwsResponse;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.SdkClient;

import java.util.List;
import java.util.Set;

import static org.apache.nifi.expression.ExpressionLanguageScope.FLOWFILE_ATTRIBUTES;

public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebServiceClient>
        extends AbstractAWSCredentialsProviderProcessor<T> {
public abstract class AbstractAwsMachineLearningJobStatusProcessor<
        T extends SdkClient,
        U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
        extends AbstractAwsSyncProcessor<T, U> {
    public static final String AWS_TASK_OUTPUT_LOCATION = "outputLocation";
    public static final PropertyDescriptor MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE =
            new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE)
@@ -81,13 +78,6 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
            .description("The job failed, the original FlowFile will be routed to this relationship.")
            .autoTerminateDefault(true)
            .build();
    public static final PropertyDescriptor REGION = new PropertyDescriptor.Builder()
            .displayName("Region")
            .name("aws-region")
            .required(true)
            .allowableValues(getAvailableRegions())
            .defaultValue(createAllowableValue(Regions.DEFAULT_REGION).getValue())
            .build();
    public static final String FAILURE_REASON_ATTRIBUTE = "failure.reason";
    protected static final List<PropertyDescriptor> PROPERTIES = List.of(
            TASK_ID,
@@ -99,18 +89,14 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
            PROXY_CONFIGURATION_SERVICE);
    private static final ObjectMapper MAPPER = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .findAndAddModules()
            .build();

    static {
        SimpleModule awsResponseModule = new SimpleModule();
        awsResponseModule.addDeserializer(ResponseMetadata.class, new AwsResponseMetadataDeserializer());
        SimpleModule sdkHttpModule = new SimpleModule();
        awsResponseModule.addDeserializer(SdkHttpMetadata.class, new SdkHttpMetadataDeserializer());
        MAPPER.registerModule(awsResponseModule);
        MAPPER.registerModule(sdkHttpModule);
    @Override
    public void migrateProperties(final PropertyConfiguration config) {
        config.renameProperty("aws-region", REGION.getName());
    }


    @Override
    public Set<Relationship> getRelationships() {
        return relationships;
@@ -129,13 +115,7 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
        return PROPERTIES;
    }

    @Override
    protected T createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                             final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
        throw new UnsupportedOperationException();
    }

    protected void writeToFlowFile(ProcessSession session, FlowFile flowFile, Object response) {
        session.write(flowFile, out -> MAPPER.writeValue(out, response));
    protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final AwsResponse response) {
        return session.write(flowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
    }
}
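A sketch of the updated writeToFlowFile behavior above: the SDK v2 response is written as JSON by serializing response.toBuilder(), presumably because the generated builder exposes bean-style accessors that Jackson can introspect while the immutable response object does not. All values below are made up.

// Editor's sketch; assumes the polly model classes and Jackson on the classpath.
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.SynthesisTask;
import software.amazon.awssdk.services.polly.model.TaskStatus;

public class ResponseSerializationSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = JsonMapper.builder().findAndAddModules().build();
        GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
                .synthesisTask(SynthesisTask.builder()
                        .taskId("example-task-id")
                        .taskStatus(TaskStatus.IN_PROGRESS)
                        .build())
                .build();
        // What the processor writes to the FlowFile: the builder view of the response.
        System.out.println(mapper.writeValueAsString(response.toBuilder()));
    }
}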
@@ -125,22 +125,6 @@
            <groupId>org.apache.nifi</groupId>
            <artifactId>nifi-schema-registry-service-api</artifactId>
        </dependency>
        <dependency>
            <groupId>com.amazonaws</groupId>
            <artifactId>aws-java-sdk-translate</artifactId>
        </dependency>
        <dependency>
            <groupId>com.amazonaws</groupId>
            <artifactId>aws-java-sdk-polly</artifactId>
        </dependency>
        <dependency>
            <groupId>com.amazonaws</groupId>
            <artifactId>aws-java-sdk-transcribe</artifactId>
        </dependency>
        <dependency>
            <groupId>com.amazonaws</groupId>
            <artifactId>aws-java-sdk-textract</artifactId>
        </dependency>
        <dependency>
            <groupId>org.bouncycastle</groupId>
            <artifactId>bcprov-jdk18on</artifactId>
@@ -17,15 +17,6 @@

package org.apache.nifi.processors.aws.ml.polly;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
import com.amazonaws.services.polly.model.TaskStatus;
import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@@ -34,9 +25,17 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.PollyClientBuilder;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.TaskStatus;

import java.util.HashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@@ -45,10 +44,10 @@ import java.util.regex.Pattern;
@SeeAlso({StartAwsPollyJob.class})
@WritesAttributes({
        @WritesAttribute(attribute = "PollyS3OutputBucket", description = "The bucket name where polly output will be located."),
        @WritesAttribute(attribute = "PollyS3OutputKey", description = "Object key of polly output."),
        @WritesAttribute(attribute = "filename", description = "Object key of polly output."),
        @WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonPollyClient> {
public class GetAwsPollyJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<PollyClient, PollyClientBuilder> {
    private static final String BUCKET = "bucket";
    private static final String KEY = "key";
    private static final Pattern S3_PATH = Pattern.compile("https://s3.*amazonaws.com/(?<" + BUCKET + ">[^/]+)/(?<" + KEY + ">.*)");
@@ -56,65 +55,66 @@ public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor<A
    private static final String AWS_S3_KEY = "filename";

    @Override
    protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                                             final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
        return (AmazonPollyClient) AmazonPollyClient.builder()
                .withCredentials(credentialsProvider)
                .withRegion(context.getProperty(REGION).getValue())
                .withEndpointConfiguration(endpointConfiguration)
                .withClientConfiguration(config)
                .build();
    public Set<Relationship> getRelationships() {
        final Set<Relationship> parentRelationships = new HashSet<>(super.getRelationships());
        parentRelationships.remove(REL_THROTTLED);
        return Set.copyOf(parentRelationships);
    }

    @Override
    public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
    protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
        return PollyClient.builder();
    }

    @Override
    public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
        FlowFile flowFile = session.get();
        if (flowFile == null) {
            return;
        }
        GetSpeechSynthesisTaskResult speechSynthesisTask;
        final GetSpeechSynthesisTaskResponse speechSynthesisTask;
        try {
            speechSynthesisTask = getSynthesisTask(context, flowFile);
        } catch (ThrottlingException e) {
            getLogger().info("Request Rate Limit exceeded", e);
            session.transfer(flowFile, REL_THROTTLED);
            return;
        } catch (Exception e) {
        } catch (final Exception e) {
            getLogger().warn("Failed to get Polly Job status", e);
            session.transfer(flowFile, REL_FAILURE);
            return;
        }

        TaskStatus taskStatus = TaskStatus.fromValue(speechSynthesisTask.getSynthesisTask().getTaskStatus());
        final TaskStatus taskStatus = speechSynthesisTask.synthesisTask().taskStatus();

        if (taskStatus == TaskStatus.InProgress || taskStatus == TaskStatus.Scheduled) {
            session.penalize(flowFile);
        if (taskStatus == TaskStatus.IN_PROGRESS || taskStatus == TaskStatus.SCHEDULED) {
            flowFile = session.penalize(flowFile);
            session.transfer(flowFile, REL_RUNNING);
        } else if (taskStatus == TaskStatus.Completed) {
            String outputUri = speechSynthesisTask.getSynthesisTask().getOutputUri();
        } else if (taskStatus == TaskStatus.COMPLETED) {
            final String outputUri = speechSynthesisTask.synthesisTask().outputUri();

            Matcher matcher = S3_PATH.matcher(outputUri);
            final Matcher matcher = S3_PATH.matcher(outputUri);
            if (matcher.find()) {
                session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET));
                session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY));
                flowFile = session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET));
                flowFile = session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY));
            }
            FlowFile childFlowFile = session.create(flowFile);
            writeToFlowFile(session, childFlowFile, speechSynthesisTask);
            childFlowFile = writeToFlowFile(session, childFlowFile, speechSynthesisTask);
            childFlowFile = session.putAttribute(childFlowFile, AWS_TASK_OUTPUT_LOCATION, outputUri);
            session.transfer(flowFile, REL_ORIGINAL);
            session.transfer(childFlowFile, REL_SUCCESS);
            getLogger().info("Amazon Polly Task Completed {}", flowFile);
        } else if (taskStatus == TaskStatus.Failed) {
            final String failureReason = speechSynthesisTask.getSynthesisTask().getTaskStatusReason();
        } else if (taskStatus == TaskStatus.FAILED) {
            final String failureReason = speechSynthesisTask.synthesisTask().taskStatusReason();
            flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
            session.transfer(flowFile, REL_FAILURE);
            getLogger().error("Amazon Polly Task Failed {} Reason [{}]", flowFile, failureReason);
        } else if (taskStatus == TaskStatus.UNKNOWN_TO_SDK_VERSION) {
            flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
            session.transfer(flowFile, REL_FAILURE);
            getLogger().error("Amazon Polly Task Failed {} Reason [Unrecognized job status]", flowFile);
        }
    }

    private GetSpeechSynthesisTaskResult getSynthesisTask(ProcessContext context, FlowFile flowFile) {
        String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
        GetSpeechSynthesisTaskRequest request = new GetSpeechSynthesisTaskRequest().withTaskId(taskId);
    private GetSpeechSynthesisTaskResponse getSynthesisTask(final ProcessContext context, final FlowFile flowFile) {
        final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
        final GetSpeechSynthesisTaskRequest request = GetSpeechSynthesisTaskRequest.builder().taskId(taskId).build();
        return getClient(context).getSpeechSynthesisTask(request);
    }
}
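For clarity on the S3_PATH pattern used by GetAwsPollyJobStatus: it extracts the bucket and key from the Polly output URI with named groups. A quick standalone check follows; the example URI is an assumption about the output format, only the regex comes from the class above.

// Editor's sketch exercising the same pattern with a made-up output URI.
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class PollyOutputPathSketch {
    private static final Pattern S3_PATH =
            Pattern.compile("https://s3.*amazonaws.com/(?<bucket>[^/]+)/(?<key>.*)");

    public static void main(String[] args) {
        Matcher matcher = S3_PATH.matcher(
                "https://s3.us-east-1.amazonaws.com/my-output-bucket/synthesis/example.mp3");
        if (matcher.find()) {
            System.out.println(matcher.group("bucket")); // my-output-bucket
            System.out.println(matcher.group("key"));    // synthesis/example.mp3
        }
    }
}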
@@ -17,49 +17,45 @@

package org.apache.nifi.processors.aws.ml.polly;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.PollyClientBuilder;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;

@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Polly"})
@CapabilityDescription("Trigger a AWS Polly job. It should be followed by GetAwsPollyJobStatus processor in order to monitor job status.")
@WritesAttributes({
        @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsPollyJobStatus")
})
@SeeAlso({GetAwsPollyJobStatus.class})
public class StartAwsPollyJob extends AwsMachineLearningJobStarter<AmazonPollyClient, StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskResult> {
public class StartAwsPollyJob extends AbstractAwsMachineLearningJobStarter<
        StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskRequest.Builder, StartSpeechSynthesisTaskResponse, PollyClient, PollyClientBuilder> {

    @Override
    protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                                             final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {

        return (AmazonPollyClient) AmazonPollyClient.builder()
                .withRegion(context.getProperty(REGION).getValue())
                .withCredentials(credentialsProvider)
                .withClientConfiguration(config)
                .withEndpointConfiguration(endpointConfiguration)
                .build();
    protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
        return PollyClient.builder();
    }

    @Override
    protected StartSpeechSynthesisTaskResult sendRequest(StartSpeechSynthesisTaskRequest request, ProcessContext context, FlowFile flowFile) {
    protected StartSpeechSynthesisTaskResponse sendRequest(final StartSpeechSynthesisTaskRequest request, final ProcessContext context, final FlowFile flowFile) {
        return getClient(context).startSpeechSynthesisTask(request);
    }

    @Override
    protected Class<? extends StartSpeechSynthesisTaskRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
        return StartSpeechSynthesisTaskRequest.class;
    protected Class<? extends StartSpeechSynthesisTaskRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
        return StartSpeechSynthesisTaskRequest.serializableBuilderClass();
    }

    @Override
    protected String getAwsTaskId(ProcessContext context, StartSpeechSynthesisTaskResult startSpeechSynthesisTaskResult, FlowFile flowFile) {
        return startSpeechSynthesisTaskResult.getSynthesisTask().getTaskId();
    protected String getAwsTaskId(final ProcessContext context, final StartSpeechSynthesisTaskResponse startSpeechSynthesisTaskResponse, final FlowFile flowFile) {
        return startSpeechSynthesisTaskResponse.synthesisTask().taskId();
    }
}
@@ -17,48 +17,61 @@

package org.apache.nifi.processors.aws.ml.textract;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.GetDocumentTextDetectionRequest;
import com.amazonaws.services.textract.model.GetExpenseAnalysisRequest;
import com.amazonaws.services.textract.model.JobStatus;
import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.TextractClientBuilder;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.JobStatus;
import software.amazon.awssdk.services.textract.model.TextractResponse;
import software.amazon.awssdk.services.textract.model.ThrottlingException;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS;
import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE;

@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Retrieves the current status of an AWS Textract job.")
@SeeAlso({StartAwsTextractJob.class})
public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTextractClient> {
public class GetAwsTextractJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TextractClient, TextractClientBuilder> {

    public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() {
        @Override
        public ValidationResult validate(final String subject, final String value, final ValidationContext context) {
            if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
                return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
            } else if (TextractType.TEXTRACT_TYPES.contains(value)) {
                return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
            } else {
                return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
            }
        }
    };

    public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
            .name("textract-type")
            .displayName("Textract Type")
            .required(true)
            .description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
            .allowableValues(TextractType.TEXTRACT_TYPES)
            .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
            .defaultValue(DOCUMENT_ANALYSIS.getType())
            .addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
            .defaultValue(String.format("${%s}", TEXTRACT_TYPE_ATTRIBUTE))
            .addValidator(TEXTRACT_TYPE_VALIDATOR)
            .build();
    private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
            Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@@ -69,30 +82,24 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
    }

    @Override
    protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                                                final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
        return (AmazonTextractClient) AmazonTextractClient.builder()
                .withRegion(context.getProperty(REGION).getValue())
                .withClientConfiguration(config)
                .withEndpointConfiguration(endpointConfiguration)
                .withCredentials(credentialsProvider)
                .build();
    protected TextractClientBuilder createClientBuilder(final ProcessContext context) {
        return TextractClient.builder();
    }

    @Override
    public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
    public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
        FlowFile flowFile = session.get();
        if (flowFile == null) {
            return;
        }
        String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue();
        final String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue();

        String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
        final String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
        try {
            JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId);
            final JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId);
            if (JobStatus.SUCCEEDED == jobStatus) {
                Object task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId);
                writeToFlowFile(session, flowFile, task);
                final TextractResponse task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId);
                flowFile = writeToFlowFile(session, flowFile, task);
                session.transfer(flowFile, REL_SUCCESS);
            } else if (JobStatus.IN_PROGRESS == jobStatus) {
                session.transfer(flowFile, REL_RUNNING);
@@ -101,29 +108,31 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
            } else if (JobStatus.FAILED == jobStatus) {
                session.transfer(flowFile, REL_FAILURE);
                getLogger().error("Amazon Textract Task [{}] Failed", awsTaskId);
            } else {
                throw new IllegalStateException("Unrecognized job status");
            }
        } catch (ThrottlingException e) {
        } catch (final ThrottlingException e) {
            getLogger().info("Request Rate Limit exceeded", e);
            session.transfer(flowFile, REL_THROTTLED);
        } catch (Exception e) {
        } catch (final Exception e) {
            getLogger().warn("Failed to get Textract Job status", e);
            session.transfer(flowFile, REL_FAILURE);
        }
    }

    private Object getTask(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) {
    private TextractResponse getTask(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
        return switch (typeOfTextract) {
            case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId));
            case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId));
            case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId));
            case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build());
            case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build());
            case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build());
        };
    }

    private JobStatus getTaskStatus(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) {
    private JobStatus getTaskStatus(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
        return switch (typeOfTextract) {
            case DOCUMENT_ANALYSIS -> JobStatus.fromValue(client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId)).getJobStatus());
            case DOCUMENT_TEXT_DETECTION -> JobStatus.fromValue(client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId)).getJobStatus());
            case EXPENSE_ANALYSIS -> JobStatus.fromValue(client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId)).getJobStatus());
            case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
            case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build()).jobStatus();
            case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
        };
    }
}
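A minimal sketch of the v2 Textract status lookup GetAwsTextractJobStatus now performs for the Document Analysis case; the job ID is a placeholder, and default region/credentials resolution is assumed.

// Editor's sketch; not code from this commit.
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.JobStatus;

public class TextractStatusSketch {
    public static void main(String[] args) {
        try (TextractClient client = TextractClient.builder().build()) {
            JobStatus status = client.getDocumentAnalysis(
                    GetDocumentAnalysisRequest.builder().jobId(args[0]).build()).jobStatus();
            System.out.println(status);
        }
    }
}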
@@ -17,31 +17,26 @@

package org.apache.nifi.processors.aws.ml.textract;

import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.AmazonWebServiceResult;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.StartDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.StartDocumentAnalysisResult;
import com.amazonaws.services.textract.model.StartDocumentTextDetectionRequest;
import com.amazonaws.services.textract.model.StartDocumentTextDetectionResult;
import com.amazonaws.services.textract.model.StartExpenseAnalysisRequest;
import com.amazonaws.services.textract.model.StartExpenseAnalysisResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.TextractClientBuilder;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;
import software.amazon.awssdk.services.textract.model.TextractRequest;
import software.amazon.awssdk.services.textract.model.TextractResponse;

import java.util.Collections;
import java.util.List;
@@ -52,28 +47,23 @@ import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_A

@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Trigger a AWS Textract job. It should be followed by GetAwsTextractJobStatus processor in order to monitor job status.")
@WritesAttributes({
        @WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTextractJobStatus"),
        @WritesAttribute(attribute = "awsTextractType", description = "The selected Textract type, which can be used in GetAwsTextractJobStatus")
})
@SeeAlso({GetAwsTextractJobStatus.class})
public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonTextractClient, AmazonWebServiceRequest, AmazonWebServiceResult> {
    public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() {
        @Override
        public ValidationResult validate(final String subject, final String value, final ValidationContext context) {
            if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
                return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
            } else if (TextractType.TEXTRACT_TYPES.contains(value)) {
                return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
            } else {
                return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
            }
        }
    };
public class StartAwsTextractJob extends AbstractAwsMachineLearningJobStarter<
        TextractRequest, TextractRequest.Builder, TextractResponse, TextractClient, TextractClientBuilder> {

    public static final String TEXTRACT_TYPE_ATTRIBUTE = "awsTextractType";

    public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
            .name("textract-type")
            .displayName("Textract Type")
            .required(true)
            .description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
            .expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
            .defaultValue(DOCUMENT_ANALYSIS.type)
            .addValidator(TEXTRACT_TYPE_VALIDATOR)
            .allowableValues(TextractType.TEXTRACT_TYPES)
            .defaultValue(DOCUMENT_ANALYSIS.getType())
            .build();
    private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
            Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@@ -84,24 +74,13 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonText
    }

    @Override
    protected void postProcessFlowFile(ProcessContext context, ProcessSession session, FlowFile flowFile, AmazonWebServiceResult response) {
        super.postProcessFlowFile(context, session, flowFile, response);
    protected TextractClientBuilder createClientBuilder(final ProcessContext context) {
        return TextractClient.builder();
    }

    @Override
    protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                                                final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
        return (AmazonTextractClient) AmazonTextractClient.builder()
                .withRegion(context.getProperty(REGION).getValue())
                .withCredentials(credentialsProvider)
                .withClientConfiguration(config)
                .withEndpointConfiguration(endpointConfiguration)
                .build();
    }

    @Override
    protected AmazonWebServiceResult sendRequest(AmazonWebServiceRequest request, ProcessContext context, FlowFile flowFile) {
        TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
    protected TextractResponse sendRequest(final TextractRequest request, final ProcessContext context, final FlowFile flowFile) {
        TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
        return switch (textractType) {
            case DOCUMENT_ANALYSIS -> getClient(context).startDocumentAnalysis((StartDocumentAnalysisRequest) request);
            case DOCUMENT_TEXT_DETECTION -> getClient(context).startDocumentTextDetection((StartDocumentTextDetectionRequest) request);
@@ -110,22 +89,28 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonText
    }

    @Override
    protected Class<? extends AmazonWebServiceRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
        final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
    protected Class<? extends TextractRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
        final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
        return switch (typeOfTextract) {
            case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.class;
            case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.class;
            case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.class;
            case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.serializableBuilderClass();
            case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.serializableBuilderClass();
            case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.serializableBuilderClass();
        };
    }

    @Override
    protected String getAwsTaskId(ProcessContext context, AmazonWebServiceResult amazonWebServiceResult, FlowFile flowFile) {
        final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
    protected String getAwsTaskId(final ProcessContext context, final TextractResponse textractResponse, final FlowFile flowFile) {
        final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
        return switch (textractType) {
            case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResult) amazonWebServiceResult).getJobId();
            case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResult) amazonWebServiceResult).getJobId();
            case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResult) amazonWebServiceResult).getJobId();
            case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResponse) textractResponse).jobId();
            case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResponse) textractResponse).jobId();
            case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResponse) textractResponse).jobId();
        };
    }

    @Override
    protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, FlowFile flowFile, final TextractResponse response) {
        flowFile = super.postProcessFlowFile(context, session, flowFile, response);
        return session.putAttribute(flowFile, TEXTRACT_TYPE_ATTRIBUTE, context.getProperty(TEXTRACT_TYPE).getValue());
    }
}
@ -17,15 +17,6 @@
|
||||
|
||||
package org.apache.nifi.processors.aws.ml.transcribe;
|
||||
|
||||
import com.amazonaws.ClientConfiguration;
|
||||
import com.amazonaws.auth.AWSCredentialsProvider;
|
||||
import com.amazonaws.client.builder.AwsClientBuilder;
|
||||
import com.amazonaws.regions.Region;
|
||||
import com.amazonaws.services.textract.model.ThrottlingException;
|
||||
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
|
||||
import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
|
||||
import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
|
||||
import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
|
||||
import org.apache.nifi.annotation.behavior.WritesAttribute;
|
||||
import org.apache.nifi.annotation.behavior.WritesAttributes;
|
||||
import org.apache.nifi.annotation.documentation.CapabilityDescription;
|
||||
@ -35,7 +26,13 @@ import org.apache.nifi.flowfile.FlowFile;
|
||||
import org.apache.nifi.processor.ProcessContext;
|
||||
import org.apache.nifi.processor.ProcessSession;
|
||||
import org.apache.nifi.processor.exception.ProcessException;
|
||||
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
|
||||
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
|
||||
import software.amazon.awssdk.services.transcribe.TranscribeClient;
|
||||
import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
|
||||
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
|
||||
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
|
||||
import software.amazon.awssdk.services.transcribe.model.LimitExceededException;
|
||||
import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;
|
||||
|
||||
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
|
||||
@CapabilityDescription("Retrieves the current status of an AWS Transcribe job.")
|
||||
@ -43,55 +40,50 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
|
||||
@WritesAttributes({
|
||||
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
|
||||
})
|
||||
public class GetAwsTranscribeJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranscribeClient> {
|
||||
public class GetAwsTranscribeJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranscribeClient, TranscribeClientBuilder> {
|
||||
|
||||
@Override
|
||||
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
|
||||
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
|
||||
return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withEndpointConfiguration(endpointConfiguration)
.withClientConfiguration(config)
.build();
protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
return TranscribeClient.builder();
}

@Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
try {
GetTranscriptionJobResult job = getJob(context, flowFile);
TranscriptionJobStatus jobStatus = TranscriptionJobStatus.fromValue(job.getTranscriptionJob().getTranscriptionJobStatus());
final GetTranscriptionJobResponse job = getJob(context, flowFile);
final TranscriptionJobStatus status = job.transcriptionJob().transcriptionJobStatus();

if (TranscriptionJobStatus.COMPLETED == jobStatus) {
writeToFlowFile(session, flowFile, job);
session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.getTranscriptionJob().getTranscript().getTranscriptFileUri());
if (TranscriptionJobStatus.COMPLETED == status) {
flowFile = writeToFlowFile(session, flowFile, job);
flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.transcriptionJob().transcript().transcriptFileUri());
session.transfer(flowFile, REL_SUCCESS);
} else if (TranscriptionJobStatus.IN_PROGRESS == jobStatus) {
} else if (TranscriptionJobStatus.IN_PROGRESS == status || TranscriptionJobStatus.QUEUED == status) {
session.transfer(flowFile, REL_RUNNING);
} else if (TranscriptionJobStatus.FAILED == jobStatus) {
final String failureReason = job.getTranscriptionJob().getFailureReason();
session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
} else if (TranscriptionJobStatus.FAILED == status) {
final String failureReason = job.transcriptionJob().failureReason();
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Transcribe Task Failed {} Reason [{}]", flowFile, failureReason);
} else {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
throw new IllegalStateException("Unrecognized job status");
}
} catch (ThrottlingException e) {
} catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
return;
} catch (Exception e) {
} catch (final Exception e) {
getLogger().warn("Failed to get Transcribe Job status", e);
session.transfer(flowFile, REL_FAILURE);
return;
}
}

private GetTranscriptionJobResult getJob(ProcessContext context, FlowFile flowFile) {
String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
GetTranscriptionJobRequest request = new GetTranscriptionJobRequest().withTranscriptionJobName(taskId);
private GetTranscriptionJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
final GetTranscriptionJobRequest request = GetTranscriptionJobRequest.builder().transcriptionJobName(taskId).build();
return getClient(context).getTranscriptionJob(request);
}
}
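For reference, the SDK 2 polling pattern that the Transcribe status hunks above migrate to boils down to a builder-based request and a typed status enum. The following is an illustrative sketch only, not part of the patch; the client and taskId parameters stand in for the processor's getClient(context) result and the evaluated Task ID property.

import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;

class TranscribeStatusSketch {
    // Polls a Transcribe job once and returns its typed status (COMPLETED, IN_PROGRESS, QUEUED, FAILED, ...)
    TranscriptionJobStatus pollStatus(final TranscribeClient client, final String taskId) {
        final GetTranscriptionJobRequest request = GetTranscriptionJobRequest.builder()
                .transcriptionJobName(taskId)
                .build();
        final GetTranscriptionJobResponse response = client.getTranscriptionJob(request);
        // SDK 2 exposes the status as an enum rather than a raw String, so no fromValue() parsing is needed
        return response.transcriptionJob().transcriptionJobStatus();
    }
}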
@ -17,48 +17,45 @@

package org.apache.nifi.processors.aws.ml.transcribe;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.StartTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.StartTranscriptionJobResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;

@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
@CapabilityDescription("Trigger a AWS Transcribe job. It should be followed by GetAwsTranscribeStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranscribeJobStatus")
})
@SeeAlso({GetAwsTranscribeJobStatus.class})
public class StartAwsTranscribeJob extends AwsMachineLearningJobStarter<AmazonTranscribeClient, StartTranscriptionJobRequest, StartTranscriptionJobResult> {
public class StartAwsTranscribeJob extends AbstractAwsMachineLearningJobStarter<
StartTranscriptionJobRequest, StartTranscriptionJobRequest.Builder, StartTranscriptionJobResponse, TranscribeClient, TranscribeClientBuilder> {

@Override
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.withCredentials(credentialsProvider)
.build();
protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
return TranscribeClient.builder();
}

@Override
protected StartTranscriptionJobResult sendRequest(StartTranscriptionJobRequest request, ProcessContext context, FlowFile flowFile) {
protected StartTranscriptionJobResponse sendRequest(final StartTranscriptionJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTranscriptionJob(request);
}

@Override
protected Class<? extends StartTranscriptionJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
return StartTranscriptionJobRequest.class;
protected Class<? extends StartTranscriptionJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartTranscriptionJobRequest.serializableBuilderClass();
}

@Override
protected String getAwsTaskId(ProcessContext context, StartTranscriptionJobResult startTranscriptionJobResult, FlowFile flowFile) {
return startTranscriptionJobResult.getTranscriptionJob().getTranscriptionJobName();
protected String getAwsTaskId(final ProcessContext context, final StartTranscriptionJobResponse startTranscriptionJobResponse, final FlowFile flowFile) {
return startTranscriptionJobResponse.transcriptionJob().transcriptionJobName();
}
}
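The switch from getAwsRequestClass to getAwsRequestBuilderClass above presumably reflects how the starter maps the incoming JSON payload onto an SDK 2 request: the payload is read into the request's serializable builder and then built. A minimal sketch of that mapping, assuming a Jackson ObjectMapper configured with case-insensitive properties as in the tests later in this patch:

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;

class RequestPayloadSketch {
    private final ObjectMapper mapper = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();

    // Deserializes the FlowFile JSON into an SDK 2 builder and builds the immutable request
    StartTranscriptionJobRequest fromJson(final String json) throws JsonProcessingException {
        return mapper.readValue(json, StartTranscriptionJobRequest.serializableBuilderClass()).build();
    }
}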
@ -17,15 +17,6 @@

package org.apache.nifi.processors.aws.ml.translate;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.model.ThrottlingException;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
import com.amazonaws.services.translate.model.JobStatus;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@ -34,8 +25,15 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.TranslateClientBuilder;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
import software.amazon.awssdk.services.translate.model.JobStatus;
import software.amazon.awssdk.services.translate.model.LimitExceededException;

@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Retrieves the current status of an AWS Translate job.")
@ -43,54 +41,66 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
@WritesAttributes({
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
public class GetAwsTranslateJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranslateClient> {
public class GetAwsTranslateJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranslateClient, TranslateClientBuilder> {

@Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTranslateClient) AmazonTranslateClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
return TranslateClient.builder();
}

@Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
try {
DescribeTextTranslationJobResult describeTextTranslationJobResult = getStatusString(context, awsTaskId);
JobStatus status = JobStatus.fromValue(describeTextTranslationJobResult.getTextTranslationJobProperties().getJobStatus());
final DescribeTextTranslationJobResponse job = getJob(context, flowFile);
final JobStatus status = job.textTranslationJobProperties().jobStatus();

if (status == JobStatus.IN_PROGRESS || status == JobStatus.SUBMITTED) {
writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
session.penalize(flowFile);
session.transfer(flowFile, REL_RUNNING);
} else if (status == JobStatus.COMPLETED) {
session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, describeTextTranslationJobResult.getTextTranslationJobProperties().getOutputDataConfig().getS3Uri());
writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
session.transfer(flowFile, REL_SUCCESS);
} else if (status == JobStatus.FAILED || status == JobStatus.COMPLETED_WITH_ERROR) {
writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
session.transfer(flowFile, REL_FAILURE);
flowFile = writeToFlowFile(session, flowFile, job);
final Relationship transferRelationship;
String failureReason = null;
switch (status) {
case IN_PROGRESS:
case SUBMITTED:
case STOP_REQUESTED:
flowFile = session.penalize(flowFile);
transferRelationship = REL_RUNNING;
break;
case COMPLETED:
flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.textTranslationJobProperties().outputDataConfig().s3Uri());
transferRelationship = REL_SUCCESS;
break;
case FAILED:
case COMPLETED_WITH_ERROR:
failureReason = job.textTranslationJobProperties().message();
transferRelationship = REL_FAILURE;
break;
case STOPPED:
failureReason = String.format("Job [%s] is stopped", job.textTranslationJobProperties().jobId());
transferRelationship = REL_FAILURE;
break;
default:
failureReason = "Unknown Job Status";
transferRelationship = REL_FAILURE;
}
} catch (ThrottlingException e) {
if (failureReason != null) {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
}
session.transfer(flowFile, transferRelationship);
} catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
} catch (Exception e) {
getLogger().warn("Failed to get Polly Job status", e);
} catch (final Exception e) {
getLogger().warn("Failed to get Translate Job status", e);
session.transfer(flowFile, REL_FAILURE);
}
}

private DescribeTextTranslationJobResult getStatusString(ProcessContext context, String awsTaskId) {
DescribeTextTranslationJobRequest request = new DescribeTextTranslationJobRequest().withJobId(awsTaskId);
DescribeTextTranslationJobResult translationJobsResult = getClient(context).describeTextTranslationJob(request);
return translationJobsResult;
private DescribeTextTranslationJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
final DescribeTextTranslationJobRequest request = DescribeTextTranslationJobRequest.builder().jobId(taskId).build();
return getClient(context).describeTextTranslationJob(request);
}
}
@ -17,47 +17,44 @@

package org.apache.nifi.processors.aws.ml.translate;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.StartTextTranslationJobRequest;
import com.amazonaws.services.translate.model.StartTextTranslationJobResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.TranslateClientBuilder;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;

@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Trigger a AWS Translate job. It should be followed by GetAwsTranslateJobStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranslateJobStatus")
})
@SeeAlso({GetAwsTranslateJobStatus.class})
public class StartAwsTranslateJob extends AwsMachineLearningJobStarter<AmazonTranslateClient, StartTextTranslationJobRequest, StartTextTranslationJobResult> {
public class StartAwsTranslateJob extends AbstractAwsMachineLearningJobStarter<
StartTextTranslationJobRequest, StartTextTranslationJobRequest.Builder, StartTextTranslationJobResponse, TranslateClient, TranslateClientBuilder> {

@Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTranslateClient) AmazonTranslateClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
return TranslateClient.builder();
}

@Override
protected StartTextTranslationJobResult sendRequest(StartTextTranslationJobRequest request, ProcessContext context, FlowFile flowFile) {
protected StartTextTranslationJobResponse sendRequest(final StartTextTranslationJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTextTranslationJob(request);
}

@Override
protected Class<StartTextTranslationJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
return StartTextTranslationJobRequest.class;
protected Class<? extends StartTextTranslationJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartTextTranslationJobRequest.serializableBuilderClass();
}

protected String getAwsTaskId(ProcessContext context, StartTextTranslationJobResult startTextTranslationJobResult, FlowFile flowFile) {
return startTextTranslationJobResult.getJobId();
protected String getAwsTaskId(final ProcessContext context, final StartTextTranslationJobResponse startTextTranslationJobResponse, final FlowFile flowFile) {
return startTextTranslationJobResponse.jobId();
}
}
@ -17,18 +17,10 @@

package org.apache.nifi.processors.aws.ml.polly;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
import com.amazonaws.services.polly.model.SynthesisTask;
import com.amazonaws.services.polly.model.TaskStatus;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@ -38,16 +30,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.SynthesisTask;
import software.amazon.awssdk.services.polly.model.TaskStatus;

import java.util.Collections;

import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ -57,73 +53,84 @@ public class GetAwsPollyStatusTest {
private static final String PLACEHOLDER_CONTENT = "content";
private TestRunner runner;
@Mock
private AmazonPollyClient mockPollyClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider;
private PollyClient mockPollyClient;

private GetAwsPollyJobStatus processor;

@Captor
private ArgumentCaptor<GetSpeechSynthesisTaskRequest> requestCaptor;

private TestRunner createRunner(final GetAwsPollyJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}

@BeforeEach
public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider");

final GetAwsPollyJobStatus mockGetAwsPollyStatus = new GetAwsPollyJobStatus() {
processor = new GetAwsPollyJobStatus() {
@Override
protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
public PollyClient getClient(final ProcessContext context) {
return mockPollyClient;
}
};
runner = TestRunners.newTestRunner(mockGetAwsPollyStatus);
runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
runner = createRunner(processor);
}

@Test
public void testPollyTaskInProgress() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
.withTaskStatus(TaskStatus.InProgress);
taskResult.setSynthesisTask(task);
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).taskStatus(TaskStatus.IN_PROGRESS).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();

runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
}

@Test
public void testPollyTaskCompleted() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
.withTaskStatus(TaskStatus.Completed)
.withOutputUri("outputLocationPath");
taskResult.setSynthesisTask(task);
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
final String uri = "https://s3.us-west2.amazonaws.com/bucket/object";
final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder()
.taskId(TEST_TASK_ID)
.taskStatus(TaskStatus.COMPLETED)
.outputUri(uri).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
runner.assertAllFlowFilesContainAttribute(REL_SUCCESS, AWS_TASK_OUTPUT_LOCATION);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
}
assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());

final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
assertEquals(uri, flowFile.getAttribute(GetAwsPollyJobStatus.AWS_TASK_OUTPUT_LOCATION));
assertEquals("bucket", flowFile.getAttribute("PollyS3OutputBucket"));
assertEquals("object", flowFile.getAttribute("filename"));
}

@Test
public void testPollyTaskFailed() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
.withTaskStatus(TaskStatus.Failed)
.withTaskStatusReason("reasonOfFailure");
taskResult.setSynthesisTask(task);
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
final String failureReason = "reasonOfFailure";
final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder()
.taskId(TEST_TASK_ID)
.taskStatus(TaskStatus.FAILED)
.taskStatusReason(failureReason).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());

final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
assertEquals(failureReason, flowFile.getAttribute(GetAwsPollyJobStatus.FAILURE_REASON_ATTRIBUTE));
}
}
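The completed-task assertions above check two derived attributes, PollyS3OutputBucket and filename, taken from the synthesis task's output URI. The processor implementation is not part of this hunk, but the parsing it is expected to perform can be sketched roughly as follows (an illustration of the assumed mapping, not the patched code):

import java.net.URI;

class PollyOutputUriSketch {
    // e.g. https://s3.us-west2.amazonaws.com/bucket/object -> bucket = "bucket", key/filename = "object"
    static String[] bucketAndKey(final String outputUri) {
        final String path = URI.create(outputUri).getPath(); // "/bucket/object"
        return path.substring(1).split("/", 2);              // ["bucket", "object"]
    }
}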
@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.nifi.processors.aws.ml.polly;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.model.Engine;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.SynthesisTask;

import java.util.HashMap;
import java.util.Map;

import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class StartAwsPollyJobTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
private PollyClient mockPollyClient;

private StartAwsPollyJob processor;

private ObjectMapper objectMapper = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.build();
@Captor
private ArgumentCaptor<StartSpeechSynthesisTaskRequest> requestCaptor;

private TestRunner createRunner(final StartAwsPollyJob processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}

@BeforeEach
public void setUp() throws InitializationException {
processor = new StartAwsPollyJob() {
@Override
public PollyClient getClient(ProcessContext context) {
return mockPollyClient;
}
};
runner = createRunner(processor);
}

@Test
public void testSuccessfulFlowfileContent() throws JsonProcessingException {
final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
.engine(Engine.NEURAL)
.text("Text")
.build();
final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build())
.build();
when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.enqueue(requestJson);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData);

assertEquals("Text", requestCaptor.getValue().text());
assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId());
}

@Test
public void testSuccessfulAttribute() throws JsonProcessingException {
final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
.engine(Engine.NEURAL)
.text("Text")
.build();
final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build())
.build();
when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.setProperty(StartAwsPollyJob.JSON_PAYLOAD, "${json.payload}");
final Map<String, String> attributes = new HashMap<>();
attributes.put("json.payload", requestJson);
runner.enqueue("", attributes);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData);

assertEquals("Text", requestCaptor.getValue().text());
assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId());
}

@Test
public void testInvalidJson() {
final String requestJson = "invalid";
runner.enqueue(requestJson);
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}

@Test
public void testServiceFailure() throws JsonProcessingException {
final StartSpeechSynthesisTaskRequest request = StartSpeechSynthesisTaskRequest.builder()
.engine(Engine.NEURAL)
.text("Text")
.build();
when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());

final String requestJson = serialize(request);
runner.enqueue(requestJson);
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}

private StartSpeechSynthesisTaskResponse deserialize(final String responseData) throws JsonProcessingException {
return objectMapper.readValue(responseData, StartSpeechSynthesisTaskResponse.serializableBuilderClass()).build();
}

private String serialize(final StartSpeechSynthesisTaskRequest request) throws JsonProcessingException {
return objectMapper.writeValueAsString(request.toBuilder());
}
}
@ -17,17 +17,10 @@

package org.apache.nifi.processors.aws.ml.textract;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.GetDocumentAnalysisResult;
import com.amazonaws.services.textract.model.JobStatus;
import com.google.common.collect.ImmutableMap;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
@ -38,13 +31,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisResponse;
import software.amazon.awssdk.services.textract.model.JobStatus;

import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_THROTTLED;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ -53,64 +53,144 @@ public class GetAwsTextractJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
private AmazonTextractClient mockTextractClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider;
private TextractClient mockTextractClient;

private GetAwsTextractJobStatus processor;

@Captor
private ArgumentCaptor<GetDocumentAnalysisRequest> requestCaptor;
private ArgumentCaptor<GetDocumentAnalysisRequest> documentAnalysisCaptor;
@Captor
private ArgumentCaptor<GetExpenseAnalysisRequest> expenseAnalysisRequestCaptor;
@Captor
private ArgumentCaptor<GetDocumentTextDetectionRequest> documentTextDetectionCaptor;

private TestRunner createRunner(final GetAwsTextractJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}

@BeforeEach
public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider");
final GetAwsTextractJobStatus awsTextractJobStatusGetter = new GetAwsTextractJobStatus() {
processor = new GetAwsTextractJobStatus() {
@Override
protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
public TextractClient getClient(final ProcessContext context) {
return mockTextractClient;
}
};
runner = TestRunners.newTestRunner(awsTextractJobStatusGetter);
runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
runner = createRunner(processor);
}

@Test
public void testTextractDocAnalysisTaskInProgress() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
.withJobStatus(JobStatus.IN_PROGRESS);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
runner.run();

runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
testTextractDocAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
}

@Test
public void testTextractDocAnalysisTaskComplete() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
.withJobStatus(JobStatus.SUCCEEDED);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
runner.run();

runner.assertAllFlowFilesTransferred(REL_SUCCESS);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
testTextractDocAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
}

@Test
public void testTextractDocAnalysisTaskFailed() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
.withJobStatus(JobStatus.FAILED);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.type));
testTextractDocAnalysis(JobStatus.FAILED, REL_FAILURE);
}

@Test
public void testTextractDocAnalysisTaskPartialSuccess() {
testTextractDocAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}

@Test
public void testTextractDocAnalysisTaskUnkownStatus() {
testTextractDocAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}

private void testTextractDocAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
final GetDocumentAnalysisResponse response = GetDocumentAnalysisResponse.builder()
.jobStatus(jobStatus).build();
when(mockTextractClient.getDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
runner.enqueue("content", ImmutableMap.of(
TASK_ID.getName(), TEST_TASK_ID,
StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE, TextractType.DOCUMENT_ANALYSIS.getType()));
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, documentAnalysisCaptor.getValue().jobId());
}

@Test
public void testTextractExpenseAnalysisTaskInProgress() {
testTextractExpenseAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
}

@Test
public void testTextractExpenseAnalysisTaskComplete() {
testTextractExpenseAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
}

@Test
public void testTextractExpenseAnalysisTaskFailed() {
testTextractExpenseAnalysis(JobStatus.FAILED, REL_FAILURE);
}

@Test
public void testTextractExpenseAnalysisTaskPartialSuccess() {
testTextractExpenseAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}

@Test
public void testTextractExpenseAnalysisTaskUnkownStatus() {
testTextractExpenseAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}

private void testTextractExpenseAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.EXPENSE_ANALYSIS.getType());
final GetExpenseAnalysisResponse response = GetExpenseAnalysisResponse.builder()
.jobStatus(jobStatus).build();
when(mockTextractClient.getExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
runner.run();

runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, expenseAnalysisRequestCaptor.getValue().jobId());
}

@Test
public void testTextractDocumentTextDetectionTaskInProgress() {
testTextractDocumentTextDetection(JobStatus.IN_PROGRESS, REL_RUNNING);
}

@Test
public void testTextractDocumentTextDetectionTaskComplete() {
testTextractDocumentTextDetection(JobStatus.SUCCEEDED, REL_SUCCESS);
}

@Test
public void testTextractDocumentTextDetectionTaskFailed() {
testTextractDocumentTextDetection(JobStatus.FAILED, REL_FAILURE);
}

@Test
public void testTextractDocumentTextDetectionTaskPartialSuccess() {
testTextractDocumentTextDetection(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}

@Test
public void testTextractDocumentTextDetectionTaskUnkownStatus() {
testTextractDocumentTextDetection(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}

private void testTextractDocumentTextDetection(final JobStatus jobStatus, final Relationship expectedRelationship) {
runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.DOCUMENT_TEXT_DETECTION.getType());

final GetDocumentTextDetectionResponse response = GetDocumentTextDetectionResponse.builder()
.jobStatus(jobStatus).build();
when(mockTextractClient.getDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
runner.run();

runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, documentTextDetectionCaptor.getValue().jobId());
}
}
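The status tests above exercise three Textract job types through separate request captors; the processor under test presumably dispatches on the configured Textract Type roughly as in this sketch (an illustration only, built from the TextractType enum and the SDK 2 calls seen in the tests, not the patched implementation):

import org.apache.nifi.processors.aws.ml.textract.TextractType;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.JobStatus;

class TextractStatusDispatchSketch {
    // Each Textract job type has its own Get* call; all expose jobId on the request and jobStatus on the response
    JobStatus pollStatus(final TextractClient client, final TextractType type, final String jobId) {
        switch (type) {
            case DOCUMENT_ANALYSIS:
                return client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(jobId).build()).jobStatus();
            case EXPENSE_ANALYSIS:
                return client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(jobId).build()).jobStatus();
            case DOCUMENT_TEXT_DETECTION:
            default:
                return client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(jobId).build()).jobStatus();
        }
    }
}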
@ -0,0 +1,342 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.nifi.processors.aws.ml.textract;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;

import java.util.HashMap;
import java.util.Map;

import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_TEXT_DETECTION;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.EXPENSE_ANALYSIS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class StartAwsTextractJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
private TextractClient mockTextractClient;

private StartAwsTextractJob processor;

private ObjectMapper objectMapper = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.build();
@Captor
private ArgumentCaptor<StartDocumentAnalysisRequest> documentAnalysisCaptor;
@Captor
private ArgumentCaptor<StartExpenseAnalysisRequest> expenseAnalysisRequestCaptor;
@Captor
private ArgumentCaptor<StartDocumentTextDetectionRequest> documentTextDetectionCaptor;

private TestRunner createRunner(final StartAwsTextractJob processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}

@BeforeEach
public void setUp() throws InitializationException {
processor = new StartAwsTextractJob() {
@Override
public TextractClient getClient(ProcessContext context) {
return mockTextractClient;
}
};
runner = createRunner(processor);
}

@Test
public void testSuccessfulDocumentAnalysisFlowfileContent() throws JsonProcessingException {
final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
.jobTag("Tag")
.build();
final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder()
.jobId(TEST_TASK_ID)
.build();
when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
runner.enqueue(requestJson);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData);

assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
assertEquals(TEST_TASK_ID, parsedResponse.jobId());
}

@Test
public void testSuccessfulDocumentAnalysisAttribute() throws JsonProcessingException {
final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
.jobTag("Tag")
.build();
final StartDocumentAnalysisResponse response = StartDocumentAnalysisResponse.builder()
.jobId(TEST_TASK_ID)
.build();
when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
final Map<String, String> attributes = new HashMap<>();
attributes.put("json.payload", requestJson);
runner.enqueue("", attributes);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartDocumentAnalysisResponse parsedResponse = deserializeDARequest(responseData);

assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
assertEquals(TEST_TASK_ID, parsedResponse.jobId());
}

@Test
public void testInvalidDocumentAnalysisJson() {
final String requestJson = "invalid";
runner.enqueue(requestJson);
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}

@Test
public void testDocumentAnalysisServiceFailure() throws JsonProcessingException {
final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
.jobTag("Tag")
.build();
when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());

final String requestJson = serialize(request);
runner.enqueue(requestJson);
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}

@Test
public void testSuccessfulExpenseAnalysisFlowfileContent() throws JsonProcessingException {
final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
.jobTag("Tag")
.build();
final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder()
.jobId(TEST_TASK_ID)
.build();
when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
runner.enqueue(requestJson);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData);

assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
assertEquals(TEST_TASK_ID, parsedResponse.jobId());
}

@Test
public void testSuccessfulExpenseAnalysisAttribute() throws JsonProcessingException {
final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
.jobTag("Tag")
.build();
final StartExpenseAnalysisResponse response = StartExpenseAnalysisResponse.builder()
.jobId(TEST_TASK_ID)
.build();
when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
final Map<String, String> attributes = new HashMap<>();
attributes.put("json.payload", requestJson);
runner.enqueue("", attributes);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartExpenseAnalysisResponse parsedResponse = deserializeEARequest(responseData);

assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
assertEquals(TEST_TASK_ID, parsedResponse.jobId());
}

@Test
public void testInvalidExpenseAnalysisJson() {
final String requestJson = "invalid";
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
runner.enqueue(requestJson);
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}

@Test
public void testExpenseAnalysisServiceFailure() throws JsonProcessingException {
final StartExpenseAnalysisRequest request = StartExpenseAnalysisRequest.builder()
.jobTag("Tag")
.build();
when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());

final String requestJson = serialize(request);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
runner.enqueue(requestJson);
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}

@Test
public void testSuccessfulDocumentTextDetectionFlowfileContent() throws JsonProcessingException {
final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
.jobTag("Tag")
.build();
final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder()
.jobId(TEST_TASK_ID)
.build();
when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
runner.enqueue(requestJson);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData);

assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
assertEquals(TEST_TASK_ID, parsedResponse.jobId());
}

@Test
public void testSuccessfulDocumentTextDetectionAttribute() throws JsonProcessingException {
final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
.jobTag("Tag")
.build();
final StartDocumentTextDetectionResponse response = StartDocumentTextDetectionResponse.builder()
.jobId(TEST_TASK_ID)
.build();
when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(response);

final String requestJson = serialize(request);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
final Map<String, String> attributes = new HashMap<>();
attributes.put("json.payload", requestJson);
runner.enqueue("", attributes);
runner.run();

runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartDocumentTextDetectionResponse parsedResponse = deserializeDTDRequest(responseData);

assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
assertEquals(TEST_TASK_ID, parsedResponse.jobId());
}

@Test
public void testInvalidDocumentTextDetectionJson() {
final String requestJson = "invalid";
runner.enqueue(requestJson);
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
runner.run();

runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}

@Test
public void testDocumentTextDetectionServiceFailure() throws JsonProcessingException {
final StartDocumentTextDetectionRequest request = StartDocumentTextDetectionRequest.builder()
.jobTag("Tag")
.build();
when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
|
||||
|
||||
final String requestJson = serialize(request);
|
||||
runner.enqueue(requestJson);
|
||||
runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
|
||||
runner.run();
|
||||
|
||||
runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
|
||||
}
|
||||
|
||||
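    // Note on the helpers below: the tests round-trip the SDK v2 Textract models through Jackson.
    // Requests are serialized from their builders and responses are rebuilt via
    // serializableBuilderClass(), which appears to mirror how the processor exchanges JSON payloads.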
    private StartDocumentTextDetectionResponse deserializeDTDRequest(final String responseData) throws JsonProcessingException {
        return objectMapper.readValue(responseData, StartDocumentTextDetectionResponse.serializableBuilderClass()).build();
    }

    private StartDocumentAnalysisResponse deserializeDARequest(final String responseData) throws JsonProcessingException {
        return objectMapper.readValue(responseData, StartDocumentAnalysisResponse.serializableBuilderClass()).build();
    }

    private StartExpenseAnalysisResponse deserializeEARequest(final String responseData) throws JsonProcessingException {
        return objectMapper.readValue(responseData, StartExpenseAnalysisResponse.serializableBuilderClass()).build();
    }

    private String serialize(final StartDocumentAnalysisRequest request) throws JsonProcessingException {
        return objectMapper.writeValueAsString(request.toBuilder());
    }

    private String serialize(final StartExpenseAnalysisRequest request) throws JsonProcessingException {
        return objectMapper.writeValueAsString(request.toBuilder());
    }

    private String serialize(final StartDocumentTextDetectionRequest request) throws JsonProcessingException {
        return objectMapper.writeValueAsString(request.toBuilder());
    }
}
@ -17,19 +17,11 @@

package org.apache.nifi.processors.aws.ml.transcribe;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
import com.amazonaws.services.transcribe.model.Transcript;
import com.amazonaws.services.transcribe.model.TranscriptionJob;
import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@ -39,93 +31,118 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.Transcript;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;

import java.util.Collections;

import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class GetAwsTranscribeJobStatusTest {
    private static final String TEST_TASK_ID = "testTaskId";
    private static final String AWS_CREDENTIAL_PROVIDER_NAME = "awsCredetialProvider";
    private static final String TEST_TASK_ID = "testJobId";
    private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
    private static final String REASON_OF_FAILURE = "reasonOfFailure";
    private static final String CONTENT_STRING = "content";
    private TestRunner runner;
    @Mock
    private AmazonTranscribeClient mockTranscribeClient;
    @Mock
    private AWSCredentialsProviderService mockAwsCredentialsProvider;
    private TranscribeClient mockTranscribeClient;

    private GetAwsTranscribeJobStatus processor;

    @Captor
    private ArgumentCaptor<GetTranscriptionJobRequest> requestCaptor;

    private TestRunner createRunner(final GetAwsTranscribeJobStatus processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIAL_PROVIDER_NAME);
        final GetAwsTranscribeJobStatus mockPollyFetcher = new GetAwsTranscribeJobStatus() {
        processor = new GetAwsTranscribeJobStatus() {
            @Override
            protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                                                          final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
            public TranscribeClient getClient(final ProcessContext context) {
                return mockTranscribeClient;
            }
        };
        runner = TestRunners.newTestRunner(mockPollyFetcher);
        runner.addControllerService(AWS_CREDENTIAL_PROVIDER_NAME, mockAwsCredentialsProvider);
        runner.enableControllerService(mockAwsCredentialsProvider);
        runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIAL_PROVIDER_NAME);
        runner = createRunner(processor);
    }

    @Test
    public void testTranscribeTaskInProgress() {
        TranscriptionJob task = new TranscriptionJob()
                .withTranscriptionJobName(TEST_TASK_ID)
                .withTranscriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS);
        GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
        when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
        runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_RUNNING);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
    public void testTranscribeJobInProgress() {
        final TranscriptionJob job = TranscriptionJob.builder()
                .transcriptionJobName(TEST_TASK_ID)
                .transcriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS)
                .build();
        testTranscribeJob(job, REL_RUNNING);
    }

    @Test
    public void testTranscribeTaskCompleted() {
        TranscriptionJob task = new TranscriptionJob()
                .withTranscriptionJobName(TEST_TASK_ID)
                .withTranscript(new Transcript().withTranscriptFileUri(OUTPUT_LOCATION_PATH))
                .withTranscriptionJobStatus(TranscriptionJobStatus.COMPLETED);
        GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
        when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
        runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_SUCCESS);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
    public void testTranscribeJobQueued() {
        final TranscriptionJob job = TranscriptionJob.builder()
                .transcriptionJobName(TEST_TASK_ID)
                .transcriptionJobStatus(TranscriptionJobStatus.QUEUED)
                .build();
        testTranscribeJob(job, REL_RUNNING);
    }

    @Test
    public void testTranscribeJobCompleted() {
        final TranscriptionJob job = TranscriptionJob.builder()
                .transcriptionJobName(TEST_TASK_ID)
                .transcript(Transcript.builder().transcriptFileUri(OUTPUT_LOCATION_PATH).build())
                .transcriptionJobStatus(TranscriptionJobStatus.COMPLETED)
                .build();
        testTranscribeJob(job, REL_SUCCESS);
        runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
        final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
        assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
    }

    @Test
    public void testPollyTaskFailed() {
        TranscriptionJob task = new TranscriptionJob()
                .withTranscriptionJobName(TEST_TASK_ID)
                .withFailureReason(REASON_OF_FAILURE)
                .withTranscriptionJobStatus(TranscriptionJobStatus.FAILED);
        GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
        when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
        runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE);
    public void testTranscribeJobFailed() {
        final TranscriptionJob job = TranscriptionJob.builder()
                .transcriptionJobName(TEST_TASK_ID)
                .failureReason(REASON_OF_FAILURE)
                .transcriptionJobStatus(TranscriptionJobStatus.FAILED)
                .build();
        testTranscribeJob(job, REL_FAILURE);
        runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
        final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
        assertEquals(REASON_OF_FAILURE, flowFile.getAttribute(FAILURE_REASON_ATTRIBUTE));
    }

    @Test
    public void testTranscribeJobUnrecognized() {
        final TranscriptionJob job = TranscriptionJob.builder()
                .transcriptionJobName(TEST_TASK_ID)
                .failureReason(REASON_OF_FAILURE)
                .transcriptionJobStatus(TranscriptionJobStatus.UNKNOWN_TO_SDK_VERSION)
                .build();
        testTranscribeJob(job, REL_FAILURE);
        runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
    }

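    // Shared helper: stubs the mocked TranscribeClient with the supplied job, runs the processor once,
    // then asserts that the FlowFile is routed to the expected relationship and that the job name from
    // the FlowFile attribute made it into the captured request.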
    private void testTranscribeJob(final TranscriptionJob job, final Relationship expectedRelationship) {
        final GetTranscriptionJobResponse response = GetTranscriptionJobResponse.builder().transcriptionJob(job).build();
        when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(response);
        runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
        runner.run();

        runner.assertAllFlowFilesTransferred(expectedRelationship);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().transcriptionJobName());
    }
}
@ -0,0 +1,162 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.processors.aws.ml.transcribe;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;

import java.util.HashMap;
import java.util.Map;

import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class StartAwsTranscribeJobTest {
    private static final String TEST_TASK_ID = "testTaskId";
    private TestRunner runner;
    @Mock
    private TranscribeClient mockTranscribeClient;

    private StartAwsTranscribeJob processor;

    private ObjectMapper objectMapper = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();
    @Captor
    private ArgumentCaptor<StartTranscriptionJobRequest> requestCaptor;

    private TestRunner createRunner(final StartAwsTranscribeJob processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

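    // getClient() is overridden below to hand back the mocked TranscribeClient, so the test exercises
    // the processor logic without creating a real AWS SDK client or making network calls.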
    @BeforeEach
    public void setUp() throws InitializationException {
        processor = new StartAwsTranscribeJob() {
            @Override
            public TranscribeClient getClient(ProcessContext context) {
                return mockTranscribeClient;
            }
        };
        runner = createRunner(processor);
    }

    @Test
    public void testSuccessfulFlowfileContent() throws JsonProcessingException {
        final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
                .transcriptionJobName("Job")
                .build();
        final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder()
                .transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build())
                .build();
        when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response);

        final String requestJson = serialize(request);
        runner.enqueue(requestJson);
        runner.run();

        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTranscriptionJobResponse parsedResponse = deserialize(responseData);

        assertEquals("Job", requestCaptor.getValue().transcriptionJobName());
        assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName());
    }

    @Test
    public void testSuccessfulAttribute() throws JsonProcessingException {
        final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
                .transcriptionJobName("Job")
                .build();
        final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder()
                .transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build())
                .build();
        when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response);

        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTranscribeJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();

        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTranscriptionJobResponse parsedResponse = deserialize(responseData);

        assertEquals("Job", requestCaptor.getValue().transcriptionJobName());
        assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName());
    }

    @Test
    public void testInvalidJson() {
        final String requestJson = "invalid";
        runner.enqueue(requestJson);
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testServiceFailure() throws JsonProcessingException {
        final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
                .transcriptionJobName("Job")
                .build();
        when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());

        final String requestJson = serialize(request);
        runner.enqueue(requestJson);
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    private StartTranscriptionJobResponse deserialize(final String responseData) throws JsonProcessingException {
        return objectMapper.readValue(responseData, StartTranscriptionJobResponse.serializableBuilderClass()).build();
    }

    private String serialize(final StartTranscriptionJobRequest request) throws JsonProcessingException {
        return objectMapper.writeValueAsString(request.toBuilder());
    }
}
@ -17,19 +17,11 @@

package org.apache.nifi.processors.aws.ml.translate;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
import com.amazonaws.services.translate.model.JobStatus;
import com.amazonaws.services.translate.model.OutputDataConfig;
import com.amazonaws.services.translate.model.TextTranslationJobProperties;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@ -39,91 +31,132 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
import software.amazon.awssdk.services.translate.model.JobStatus;
import software.amazon.awssdk.services.translate.model.OutputDataConfig;
import software.amazon.awssdk.services.translate.model.TextTranslationJobProperties;

import java.time.Instant;
import java.util.Collections;

import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class GetAwsTranslateJobStatusTest {
    private static final String TEST_TASK_ID = "testTaskId";
    private static final String TEST_TASK_ID = "testJobId";
    private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
    private static final String REASON_OF_FAILURE = "reasonOfFailure";
    private static final String CONTENT_STRING = "content";
    private static final String AWS_CREDENTIALS_PROVIDER_NAME = "awsCredetialProvider";
    private static final String OUTPUT_LOCATION_PATH = "outputLocation";
    private TestRunner runner;
    @Mock
    private AmazonTranslateClient mockTranslateClient;
    @Mock
    private AWSCredentialsProviderService mockAwsCredentialsProvider;
    private TranslateClient mockTranslateClient;

    private GetAwsTranslateJobStatus processor;

    @Captor
    private ArgumentCaptor<DescribeTextTranslationJobRequest> requestCaptor;

    private TestRunner createRunner(final GetAwsTranslateJobStatus processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIALS_PROVIDER_NAME);
        final GetAwsTranslateJobStatus mockPollyFetcher = new GetAwsTranslateJobStatus() {
        processor = new GetAwsTranslateJobStatus() {
            @Override
            protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
                                                         final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
            public TranslateClient getClient(final ProcessContext context) {
                return mockTranslateClient;
            }
        };
        runner = TestRunners.newTestRunner(mockPollyFetcher);
        runner.addControllerService(AWS_CREDENTIALS_PROVIDER_NAME, mockAwsCredentialsProvider);
        runner.enableControllerService(mockAwsCredentialsProvider);
        runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIALS_PROVIDER_NAME);
        runner = createRunner(processor);
    }

    @Test
    public void testTranscribeTaskInProgress() {
        TextTranslationJobProperties task = new TextTranslationJobProperties()
                .withJobId(TEST_TASK_ID)
                .withJobStatus(JobStatus.IN_PROGRESS);
        DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
        when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
        runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_RUNNING);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
    public void testTranslateJobInProgress() {
        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
                .jobId(TEST_TASK_ID)
                .jobStatus(JobStatus.IN_PROGRESS)
                .build();
        testTranslateJob(job, REL_RUNNING);
    }

    @Test
    public void testTranscribeTaskCompleted() {
        TextTranslationJobProperties task = new TextTranslationJobProperties()
                .withJobId(TEST_TASK_ID)
                .withOutputDataConfig(new OutputDataConfig().withS3Uri(OUTPUT_LOCATION_PATH))
                .withJobStatus(JobStatus.COMPLETED);
        DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
        when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
        runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
        runner.run();
    public void testTranslateSubmitted() {
        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
                .jobId(TEST_TASK_ID)
                .jobStatus(JobStatus.SUBMITTED)
                .build();
        testTranslateJob(job, REL_RUNNING);
    }

        runner.assertAllFlowFilesTransferred(REL_SUCCESS);
    @Test
    public void testTranslateStopRequested() {
        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
                .jobId(TEST_TASK_ID)
                .jobStatus(JobStatus.STOP_REQUESTED)
                .build();
        testTranslateJob(job, REL_RUNNING);
    }

    @Test
    public void testTranslateJobCompleted() {
        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
                .jobStatus(TEST_TASK_ID)
                .outputDataConfig(OutputDataConfig.builder().s3Uri(OUTPUT_LOCATION_PATH).build())
                .submittedTime(Instant.now())
                .jobStatus(JobStatus.COMPLETED)
                .build();
        testTranslateJob(job, REL_SUCCESS);
        runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
        final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
        assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
    }

    @Test
    public void testTranscribeTaskFailed() {
        TextTranslationJobProperties task = new TextTranslationJobProperties()
                .withJobId(TEST_TASK_ID)
                .withJobStatus(JobStatus.FAILED);
        DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
        when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
    public void testTranslateJobFailed() {
        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
                .jobStatus(TEST_TASK_ID)
                .jobStatus(JobStatus.FAILED)
                .build();
        testTranslateJob(job, REL_FAILURE);
    }

    @Test
    public void testTranslateJobStopped() {
        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
                .jobStatus(TEST_TASK_ID)
                .jobStatus(JobStatus.STOPPED)
                .build();
        testTranslateJob(job, REL_FAILURE);
    }

    @Test
    public void testTranslateJobUnrecognized() {
        final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
                .jobStatus(TEST_TASK_ID)
                .jobStatus(JobStatus.UNKNOWN_TO_SDK_VERSION)
                .build();
        testTranslateJob(job, REL_FAILURE);
    }

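    // Shared helper: the tests above feed different JobStatus values through testTranslateJob and expect
    // SUBMITTED, IN_PROGRESS and STOP_REQUESTED jobs to stay on the running relationship, while STOPPED,
    // FAILED and unrecognized statuses are routed to failure.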
    private void testTranslateJob(final TextTranslationJobProperties job, final Relationship expectedRelationship) {
        final DescribeTextTranslationJobResponse response = DescribeTextTranslationJobResponse.builder().textTranslationJobProperties(job).build();
        when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(response);
        runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
        runner.assertAllFlowFilesTransferred(expectedRelationship);
        assertEquals(TEST_TASK_ID, requestCaptor.getValue().jobId());
    }

}
@ -0,0 +1,162 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.processors.aws.ml.translate;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class StartAwsTranslateJobTest {
    private static final String TEST_TASK_ID = "testTaskId";
    private TestRunner runner;
    @Mock
    private TranslateClient mockTranslateClient;

    private StartAwsTranslateJob processor;

    private ObjectMapper objectMapper = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();
    @Captor
    private ArgumentCaptor<StartTextTranslationJobRequest> requestCaptor;

    private TestRunner createRunner(final StartAwsTranslateJob processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        processor = new StartAwsTranslateJob() {
            @Override
            public TranslateClient getClient(ProcessContext context) {
                return mockTranslateClient;
            }
        };
        runner = createRunner(processor);
    }

    @Test
    public void testSuccessfulFlowfileContent() throws JsonProcessingException {
        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
                .terminologyNames("Name")
                .build();
        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);

        final String requestJson = serialize(request);
        runner.enqueue(requestJson);
        runner.run();

        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);

        assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testSuccessfulAttribute() throws JsonProcessingException {
        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
                .terminologyNames("Name")
                .build();
        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);

        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTranslateJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();

        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);

        assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    @Test
    public void testInvalidJson() {
        final String requestJson = "invalid";
        runner.enqueue(requestJson);
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testServiceFailure() throws JsonProcessingException {
        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
                .terminologyNames("Name")
                .build();
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());

        final String requestJson = serialize(request);
        runner.enqueue(requestJson);
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

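    // The helpers below serialize the request via its builder and rebuild the response through
    // serializableBuilderClass(); the case-insensitive ObjectMapper above seems intended to keep the
    // builder property names matching regardless of field-name casing in the JSON payload.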
    private StartTextTranslationJobResponse deserialize(final String responseData) throws JsonProcessingException {
        return objectMapper.readValue(responseData, StartTextTranslationJobResponse.serializableBuilderClass()).build();
    }

    private String serialize(final StartTextTranslationJobRequest request) throws JsonProcessingException {
        return objectMapper.writeValueAsString(request.toBuilder());
    }
}
@ -66,7 +66,6 @@ public abstract class AbstractSQSIT {
        queueUrl = response.queueUrl();
    }

    @AfterAll
    public static void shutdown() {
        client.close();