NIFI-12263 Upgraded AWS Machine Learning processors to SDK 2

This closes #7953

Signed-off-by: David Handermann <exceptionfactory@apache.org>
This commit is contained in:
Joe Gresock 2023-10-28 11:40:59 -04:00 committed by exceptionfactory
parent c706877147
commit 77834c92df
No known key found for this signature in database
GPG Key ID: 29B6A52D2AAE8DBA
21 changed files with 1542 additions and 608 deletions

View File

@ -65,6 +65,22 @@
<groupId>software.amazon.awssdk</groupId>
<artifactId>firehose</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>polly</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>textract</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>transcribe</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>translate</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.kinesis</groupId>
<artifactId>amazon-kinesis-client</artifactId>
@ -130,6 +146,10 @@
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
</dependency>
<!-- Version 2 of the AmazonS3EncryptionClient requires bouncy castle -->
<dependency>
<groupId>org.bouncycastle</groupId>

View File

@ -17,14 +17,6 @@
package org.apache.nifi.processors.aws.ml;
import com.amazonaws.AmazonWebServiceClient;
import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.AmazonWebServiceResult;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
@ -33,12 +25,18 @@ import org.apache.commons.io.IOUtils;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor;
import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
import software.amazon.awssdk.awscore.AwsRequest;
import software.amazon.awssdk.awscore.AwsResponse;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.SdkClient;
import java.io.IOException;
import java.io.InputStream;
@ -46,10 +44,15 @@ import java.util.List;
import java.util.Set;
import static org.apache.nifi.flowfile.attributes.CoreAttributes.MIME_TYPE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceClient, REQUEST extends AmazonWebServiceRequest, RESPONSE extends AmazonWebServiceResult>
extends AbstractAWSCredentialsProviderProcessor<T> {
public abstract class AbstractAwsMachineLearningJobStarter<
Q extends AwsRequest,
B extends AwsRequest.Builder,
R extends AwsResponse,
T extends SdkClient,
U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
extends AbstractAwsSyncProcessor<T, U> {
public static final PropertyDescriptor JSON_PAYLOAD = new PropertyDescriptor.Builder()
.name("json-payload")
.displayName("JSON Payload")
@ -62,18 +65,17 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE)
.required(true)
.build();
public static final PropertyDescriptor REGION = new PropertyDescriptor.Builder()
.displayName("Region")
.name("aws-region")
.required(true)
.allowableValues(getAvailableRegions())
.defaultValue(createAllowableValue(Regions.DEFAULT_REGION).getValue())
.build();
public static final Relationship REL_ORIGINAL = new Relationship.Builder()
.name("original")
.description("Upon successful completion, the original FlowFile will be routed to this relationship.")
.autoTerminateDefault(true)
.build();
@Override
public void migrateProperties(final PropertyConfiguration config) {
config.renameProperty("aws-region", REGION.getName());
}
protected static final List<PropertyDescriptor> PROPERTIES = List.of(
MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE,
REGION,
@ -84,10 +86,9 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
private final static ObjectMapper MAPPER = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.findAndAddModules()
.build();
private static final Set<Relationship> relationships = Set.of(REL_ORIGINAL,
REL_SUCCESS,
REL_FAILURE);
private static final Set<Relationship> relationships = Set.of(REL_ORIGINAL, REL_SUCCESS, REL_FAILURE);
@Override
public Set<Relationship> getRelationships() {
@ -105,14 +106,14 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
if (flowFile == null && !context.getProperty(JSON_PAYLOAD).isSet()) {
return;
}
final RESPONSE response;
final R response;
FlowFile childFlowFile;
try {
response = sendRequest(buildRequest(session, context, flowFile), context, flowFile);
childFlowFile = writeToFlowFile(session, flowFile, response);
postProcessFlowFile(context, session, childFlowFile, response);
childFlowFile = postProcessFlowFile(context, session, childFlowFile, response);
session.transfer(childFlowFile, REL_SUCCESS);
} catch (Exception e) {
} catch (final Exception e) {
if (flowFile != null) {
session.transfer(flowFile, REL_FAILURE);
}
@ -125,26 +126,21 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
}
protected void postProcessFlowFile(ProcessContext context, ProcessSession session, FlowFile flowFile, RESPONSE response) {
protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, final FlowFile flowFile, final R response) {
final String awsTaskId = getAwsTaskId(context, response, flowFile);
flowFile = session.putAttribute(flowFile, TASK_ID.getName(), awsTaskId);
flowFile = session.putAttribute(flowFile, MIME_TYPE.key(), "application/json");
FlowFile processedFlowFile = session.putAttribute(flowFile, TASK_ID.getName(), awsTaskId);
processedFlowFile = session.putAttribute(processedFlowFile, MIME_TYPE.key(), "application/json");
getLogger().debug("AWS ML Task [{}] started", awsTaskId);
return processedFlowFile;
}
protected REQUEST buildRequest(ProcessSession session, ProcessContext context, FlowFile flowFile) throws JsonProcessingException {
return MAPPER.readValue(getPayload(session, context, flowFile), getAwsRequestClass(context, flowFile));
protected Q buildRequest(final ProcessSession session, final ProcessContext context, final FlowFile flowFile) throws JsonProcessingException {
return (Q) MAPPER.readValue(getPayload(session, context, flowFile), getAwsRequestBuilderClass(context, flowFile)).build();
}
@Override
protected T createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
throw new UnsupportedOperationException();
}
protected FlowFile writeToFlowFile(ProcessSession session, FlowFile flowFile, RESPONSE response) {
protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final R response) {
FlowFile childFlowFile = flowFile == null ? session.create() : session.create(flowFile);
childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response));
childFlowFile = session.write(childFlowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
return childFlowFile;
}
@ -156,7 +152,7 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
}
}
private String getPayload(ProcessSession session, ProcessContext context, FlowFile flowFile) {
private String getPayload(final ProcessSession session, final ProcessContext context, final FlowFile flowFile) {
String payloadPropertyValue = context.getProperty(JSON_PAYLOAD).evaluateAttributeExpressions(flowFile).getValue();
if (payloadPropertyValue == null) {
payloadPropertyValue = readFlowFile(session, flowFile);
@ -164,9 +160,9 @@ public abstract class AwsMachineLearningJobStarter<T extends AmazonWebServiceCli
return payloadPropertyValue;
}
abstract protected RESPONSE sendRequest(REQUEST request, ProcessContext context, FlowFile flowFile) throws JsonProcessingException;
abstract protected R sendRequest(Q request, ProcessContext context, FlowFile flowFile) throws JsonProcessingException;
abstract protected Class<? extends REQUEST> getAwsRequestClass(ProcessContext context, FlowFile flowFile);
abstract protected Class<? extends B> getAwsRequestBuilderClass(ProcessContext context, FlowFile flowFile);
abstract protected String getAwsTaskId(ProcessContext context, RESPONSE response, FlowFile flowFile);
abstract protected String getAwsTaskId(ProcessContext context, R response, FlowFile flowFile);
}

View File

@ -17,33 +17,30 @@
package org.apache.nifi.processors.aws.ml;
import com.amazonaws.AmazonWebServiceClient;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.ResponseMetadata;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.http.SdkHttpMetadata;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.migration.PropertyConfiguration;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor;
import org.apache.nifi.processors.aws.v2.AbstractAwsSyncProcessor;
import software.amazon.awssdk.awscore.AwsResponse;
import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder;
import software.amazon.awssdk.core.SdkClient;
import java.util.List;
import java.util.Set;
import static org.apache.nifi.expression.ExpressionLanguageScope.FLOWFILE_ATTRIBUTES;
public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebServiceClient>
extends AbstractAWSCredentialsProviderProcessor<T> {
public abstract class AbstractAwsMachineLearningJobStatusProcessor<
T extends SdkClient,
U extends AwsSyncClientBuilder<U, T> & AwsClientBuilder<U, T>>
extends AbstractAwsSyncProcessor<T, U> {
public static final String AWS_TASK_OUTPUT_LOCATION = "outputLocation";
public static final PropertyDescriptor MANDATORY_AWS_CREDENTIALS_PROVIDER_SERVICE =
new PropertyDescriptor.Builder().fromPropertyDescriptor(AWS_CREDENTIALS_PROVIDER_SERVICE)
@ -81,13 +78,6 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
.description("The job failed, the original FlowFile will be routed to this relationship.")
.autoTerminateDefault(true)
.build();
public static final PropertyDescriptor REGION = new PropertyDescriptor.Builder()
.displayName("Region")
.name("aws-region")
.required(true)
.allowableValues(getAvailableRegions())
.defaultValue(createAllowableValue(Regions.DEFAULT_REGION).getValue())
.build();
public static final String FAILURE_REASON_ATTRIBUTE = "failure.reason";
protected static final List<PropertyDescriptor> PROPERTIES = List.of(
TASK_ID,
@ -99,18 +89,14 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
PROXY_CONFIGURATION_SERVICE);
private static final ObjectMapper MAPPER = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.findAndAddModules()
.build();
static {
SimpleModule awsResponseModule = new SimpleModule();
awsResponseModule.addDeserializer(ResponseMetadata.class, new AwsResponseMetadataDeserializer());
SimpleModule sdkHttpModule = new SimpleModule();
awsResponseModule.addDeserializer(SdkHttpMetadata.class, new SdkHttpMetadataDeserializer());
MAPPER.registerModule(awsResponseModule);
MAPPER.registerModule(sdkHttpModule);
@Override
public void migrateProperties(final PropertyConfiguration config) {
config.renameProperty("aws-region", REGION.getName());
}
@Override
public Set<Relationship> getRelationships() {
return relationships;
@ -129,13 +115,7 @@ public abstract class AwsMachineLearningJobStatusProcessor<T extends AmazonWebSe
return PROPERTIES;
}
@Override
protected T createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
throw new UnsupportedOperationException();
}
protected void writeToFlowFile(ProcessSession session, FlowFile flowFile, Object response) {
session.write(flowFile, out -> MAPPER.writeValue(out, response));
protected FlowFile writeToFlowFile(final ProcessSession session, final FlowFile flowFile, final AwsResponse response) {
return session.write(flowFile, out -> MAPPER.writeValue(out, response.toBuilder()));
}
}

View File

@ -125,22 +125,6 @@
<groupId>org.apache.nifi</groupId>
<artifactId>nifi-schema-registry-service-api</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-translate</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-polly</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-transcribe</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-textract</artifactId>
</dependency>
<dependency>
<groupId>org.bouncycastle</groupId>
<artifactId>bcprov-jdk18on</artifactId>

View File

@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.polly;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
import com.amazonaws.services.polly.model.TaskStatus;
import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@ -34,9 +25,17 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.PollyClientBuilder;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.TaskStatus;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@ -45,10 +44,10 @@ import java.util.regex.Pattern;
@SeeAlso({StartAwsPollyJob.class})
@WritesAttributes({
@WritesAttribute(attribute = "PollyS3OutputBucket", description = "The bucket name where polly output will be located."),
@WritesAttribute(attribute = "PollyS3OutputKey", description = "Object key of polly output."),
@WritesAttribute(attribute = "filename", description = "Object key of polly output."),
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonPollyClient> {
public class GetAwsPollyJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<PollyClient, PollyClientBuilder> {
private static final String BUCKET = "bucket";
private static final String KEY = "key";
private static final Pattern S3_PATH = Pattern.compile("https://s3.*amazonaws.com/(?<" + BUCKET + ">[^/]+)/(?<" + KEY + ">.*)");
@ -56,65 +55,66 @@ public class GetAwsPollyJobStatus extends AwsMachineLearningJobStatusProcessor<A
private static final String AWS_S3_KEY = "filename";
@Override
protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonPollyClient) AmazonPollyClient.builder()
.withCredentials(credentialsProvider)
.withRegion(context.getProperty(REGION).getValue())
.withEndpointConfiguration(endpointConfiguration)
.withClientConfiguration(config)
.build();
public Set<Relationship> getRelationships() {
final Set<Relationship> parentRelationships = new HashSet<>(super.getRelationships());
parentRelationships.remove(REL_THROTTLED);
return Set.copyOf(parentRelationships);
}
@Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
return PollyClient.builder();
}
@Override
public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
GetSpeechSynthesisTaskResult speechSynthesisTask;
final GetSpeechSynthesisTaskResponse speechSynthesisTask;
try {
speechSynthesisTask = getSynthesisTask(context, flowFile);
} catch (ThrottlingException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
return;
} catch (Exception e) {
} catch (final Exception e) {
getLogger().warn("Failed to get Polly Job status", e);
session.transfer(flowFile, REL_FAILURE);
return;
}
TaskStatus taskStatus = TaskStatus.fromValue(speechSynthesisTask.getSynthesisTask().getTaskStatus());
final TaskStatus taskStatus = speechSynthesisTask.synthesisTask().taskStatus();
if (taskStatus == TaskStatus.InProgress || taskStatus == TaskStatus.Scheduled) {
session.penalize(flowFile);
if (taskStatus == TaskStatus.IN_PROGRESS || taskStatus == TaskStatus.SCHEDULED) {
flowFile = session.penalize(flowFile);
session.transfer(flowFile, REL_RUNNING);
} else if (taskStatus == TaskStatus.Completed) {
String outputUri = speechSynthesisTask.getSynthesisTask().getOutputUri();
} else if (taskStatus == TaskStatus.COMPLETED) {
final String outputUri = speechSynthesisTask.synthesisTask().outputUri();
Matcher matcher = S3_PATH.matcher(outputUri);
final Matcher matcher = S3_PATH.matcher(outputUri);
if (matcher.find()) {
session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET));
session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY));
flowFile = session.putAttribute(flowFile, AWS_S3_BUCKET, matcher.group(BUCKET));
flowFile = session.putAttribute(flowFile, AWS_S3_KEY, matcher.group(KEY));
}
FlowFile childFlowFile = session.create(flowFile);
writeToFlowFile(session, childFlowFile, speechSynthesisTask);
childFlowFile = writeToFlowFile(session, childFlowFile, speechSynthesisTask);
childFlowFile = session.putAttribute(childFlowFile, AWS_TASK_OUTPUT_LOCATION, outputUri);
session.transfer(flowFile, REL_ORIGINAL);
session.transfer(childFlowFile, REL_SUCCESS);
getLogger().info("Amazon Polly Task Completed {}", flowFile);
} else if (taskStatus == TaskStatus.Failed) {
final String failureReason = speechSynthesisTask.getSynthesisTask().getTaskStatusReason();
} else if (taskStatus == TaskStatus.FAILED) {
final String failureReason = speechSynthesisTask.synthesisTask().taskStatusReason();
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Polly Task Failed {} Reason [{}]", flowFile, failureReason);
} else if (taskStatus == TaskStatus.UNKNOWN_TO_SDK_VERSION) {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Polly Task Failed {} Reason [Unrecognized job status]", flowFile);
}
}
private GetSpeechSynthesisTaskResult getSynthesisTask(ProcessContext context, FlowFile flowFile) {
String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
GetSpeechSynthesisTaskRequest request = new GetSpeechSynthesisTaskRequest().withTaskId(taskId);
private GetSpeechSynthesisTaskResponse getSynthesisTask(final ProcessContext context, final FlowFile flowFile) {
final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
final GetSpeechSynthesisTaskRequest request = GetSpeechSynthesisTaskRequest.builder().taskId(taskId).build();
return getClient(context).getSpeechSynthesisTask(request);
}
}

View File

@ -17,49 +17,45 @@
package org.apache.nifi.processors.aws.ml.polly;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.StartSpeechSynthesisTaskResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.PollyClientBuilder;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Polly"})
@CapabilityDescription("Trigger a AWS Polly job. It should be followed by GetAwsPollyJobStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsPollyJobStatus")
})
@SeeAlso({GetAwsPollyJobStatus.class})
public class StartAwsPollyJob extends AwsMachineLearningJobStarter<AmazonPollyClient, StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskResult> {
public class StartAwsPollyJob extends AbstractAwsMachineLearningJobStarter<
StartSpeechSynthesisTaskRequest, StartSpeechSynthesisTaskRequest.Builder, StartSpeechSynthesisTaskResponse, PollyClient, PollyClientBuilder> {
@Override
protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonPollyClient) AmazonPollyClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
protected PollyClientBuilder createClientBuilder(final ProcessContext context) {
return PollyClient.builder();
}
@Override
protected StartSpeechSynthesisTaskResult sendRequest(StartSpeechSynthesisTaskRequest request, ProcessContext context, FlowFile flowFile) {
protected StartSpeechSynthesisTaskResponse sendRequest(final StartSpeechSynthesisTaskRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startSpeechSynthesisTask(request);
}
@Override
protected Class<? extends StartSpeechSynthesisTaskRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
return StartSpeechSynthesisTaskRequest.class;
protected Class<? extends StartSpeechSynthesisTaskRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartSpeechSynthesisTaskRequest.serializableBuilderClass();
}
@Override
protected String getAwsTaskId(ProcessContext context, StartSpeechSynthesisTaskResult startSpeechSynthesisTaskResult, FlowFile flowFile) {
return startSpeechSynthesisTaskResult.getSynthesisTask().getTaskId();
protected String getAwsTaskId(final ProcessContext context, final StartSpeechSynthesisTaskResponse startSpeechSynthesisTaskResponse, final FlowFile flowFile) {
return startSpeechSynthesisTaskResponse.synthesisTask().taskId();
}
}

View File

@ -17,48 +17,61 @@
package org.apache.nifi.processors.aws.ml.textract;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.GetDocumentTextDetectionRequest;
import com.amazonaws.services.textract.model.GetExpenseAnalysisRequest;
import com.amazonaws.services.textract.model.JobStatus;
import com.amazonaws.services.textract.model.ThrottlingException;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.TextractClientBuilder;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.JobStatus;
import software.amazon.awssdk.services.textract.model.TextractResponse;
import software.amazon.awssdk.services.textract.model.ThrottlingException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS;
import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Retrieves the current status of an AWS Textract job.")
@SeeAlso({StartAwsTextractJob.class})
public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTextractClient> {
public class GetAwsTextractJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TextractClient, TextractClientBuilder> {
public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() {
@Override
public ValidationResult validate(final String subject, final String value, final ValidationContext context) {
if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
} else if (TextractType.TEXTRACT_TYPES.contains(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
} else {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
}
}
};
public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
.name("textract-type")
.displayName("Textract Type")
.required(true)
.description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
.allowableValues(TextractType.TEXTRACT_TYPES)
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
.defaultValue(DOCUMENT_ANALYSIS.getType())
.addValidator(StandardValidators.NON_EMPTY_VALIDATOR)
.defaultValue(String.format("${%s}", TEXTRACT_TYPE_ATTRIBUTE))
.addValidator(TEXTRACT_TYPE_VALIDATOR)
.build();
private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@ -69,30 +82,24 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
}
@Override
protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTextractClient) AmazonTextractClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.withCredentials(credentialsProvider)
.build();
protected TextractClientBuilder createClientBuilder(final ProcessContext context) {
return TextractClient.builder();
}
@Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue();
final String textractType = context.getProperty(TEXTRACT_TYPE).evaluateAttributeExpressions(flowFile).getValue();
String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
final String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
try {
JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId);
final JobStatus jobStatus = getTaskStatus(TextractType.fromString(textractType), getClient(context), awsTaskId);
if (JobStatus.SUCCEEDED == jobStatus) {
Object task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId);
writeToFlowFile(session, flowFile, task);
final TextractResponse task = getTask(TextractType.fromString(textractType), getClient(context), awsTaskId);
flowFile = writeToFlowFile(session, flowFile, task);
session.transfer(flowFile, REL_SUCCESS);
} else if (JobStatus.IN_PROGRESS == jobStatus) {
session.transfer(flowFile, REL_RUNNING);
@ -101,29 +108,31 @@ public class GetAwsTextractJobStatus extends AwsMachineLearningJobStatusProcesso
} else if (JobStatus.FAILED == jobStatus) {
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Amazon Textract Task [{}] Failed", awsTaskId);
} else {
throw new IllegalStateException("Unrecognized job status");
}
} catch (ThrottlingException e) {
} catch (final ThrottlingException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
} catch (Exception e) {
} catch (final Exception e) {
getLogger().warn("Failed to get Textract Job status", e);
session.transfer(flowFile, REL_FAILURE);
}
}
private Object getTask(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) {
private TextractResponse getTask(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
return switch (typeOfTextract) {
case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId));
case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId));
case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId));
case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build());
case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build());
case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build());
};
}
private JobStatus getTaskStatus(TextractType typeOfTextract, AmazonTextractClient client, String awsTaskId) {
private JobStatus getTaskStatus(final TextractType typeOfTextract, final TextractClient client, final String awsTaskId) {
return switch (typeOfTextract) {
case DOCUMENT_ANALYSIS -> JobStatus.fromValue(client.getDocumentAnalysis(new GetDocumentAnalysisRequest().withJobId(awsTaskId)).getJobStatus());
case DOCUMENT_TEXT_DETECTION -> JobStatus.fromValue(client.getDocumentTextDetection(new GetDocumentTextDetectionRequest().withJobId(awsTaskId)).getJobStatus());
case EXPENSE_ANALYSIS -> JobStatus.fromValue(client.getExpenseAnalysis(new GetExpenseAnalysisRequest().withJobId(awsTaskId)).getJobStatus());
case DOCUMENT_ANALYSIS -> client.getDocumentAnalysis(GetDocumentAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
case DOCUMENT_TEXT_DETECTION -> client.getDocumentTextDetection(GetDocumentTextDetectionRequest.builder().jobId(awsTaskId).build()).jobStatus();
case EXPENSE_ANALYSIS -> client.getExpenseAnalysis(GetExpenseAnalysisRequest.builder().jobId(awsTaskId).build()).jobStatus();
};
}
}

View File

@ -17,31 +17,26 @@
package org.apache.nifi.processors.aws.ml.textract;
import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.AmazonWebServiceResult;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.StartDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.StartDocumentAnalysisResult;
import com.amazonaws.services.textract.model.StartDocumentTextDetectionRequest;
import com.amazonaws.services.textract.model.StartDocumentTextDetectionResult;
import com.amazonaws.services.textract.model.StartExpenseAnalysisRequest;
import com.amazonaws.services.textract.model.StartExpenseAnalysisResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.components.PropertyDescriptor;
import org.apache.nifi.components.ValidationContext;
import org.apache.nifi.components.ValidationResult;
import org.apache.nifi.components.Validator;
import org.apache.nifi.expression.ExpressionLanguageScope;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.TextractClientBuilder;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;
import software.amazon.awssdk.services.textract.model.TextractRequest;
import software.amazon.awssdk.services.textract.model.TextractResponse;
import java.util.Collections;
import java.util.List;
@ -52,28 +47,23 @@ import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_A
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Textract"})
@CapabilityDescription("Trigger a AWS Textract job. It should be followed by GetAwsTextractJobStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTextractJobStatus"),
@WritesAttribute(attribute = "awsTextractType", description = "The selected Textract type, which can be used in GetAwsTextractJobStatus")
})
@SeeAlso({GetAwsTextractJobStatus.class})
public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonTextractClient, AmazonWebServiceRequest, AmazonWebServiceResult> {
public static final Validator TEXTRACT_TYPE_VALIDATOR = new Validator() {
@Override
public ValidationResult validate(final String subject, final String value, final ValidationContext context) {
if (context.isExpressionLanguageSupported(subject) && context.isExpressionLanguagePresent(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Expression Language Present").valid(true).build();
} else if (TextractType.TEXTRACT_TYPES.contains(value)) {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Supported Value.").valid(true).build();
} else {
return new ValidationResult.Builder().subject(subject).input(value).explanation("Not a supported value, flow file attribute or context parameter.").valid(false).build();
}
}
};
public class StartAwsTextractJob extends AbstractAwsMachineLearningJobStarter<
TextractRequest, TextractRequest.Builder, TextractResponse, TextractClient, TextractClientBuilder> {
public static final String TEXTRACT_TYPE_ATTRIBUTE = "awsTextractType";
public static final PropertyDescriptor TEXTRACT_TYPE = new PropertyDescriptor.Builder()
.name("textract-type")
.displayName("Textract Type")
.required(true)
.description("Supported values: \"Document Analysis\", \"Document Text Detection\", \"Expense Analysis\"")
.expressionLanguageSupported(ExpressionLanguageScope.FLOWFILE_ATTRIBUTES)
.defaultValue(DOCUMENT_ANALYSIS.type)
.addValidator(TEXTRACT_TYPE_VALIDATOR)
.allowableValues(TextractType.TEXTRACT_TYPES)
.defaultValue(DOCUMENT_ANALYSIS.getType())
.build();
private static final List<PropertyDescriptor> TEXTRACT_PROPERTIES =
Collections.unmodifiableList(Stream.concat(PROPERTIES.stream(), Stream.of(TEXTRACT_TYPE)).collect(Collectors.toList()));
@ -84,24 +74,13 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonText
}
@Override
protected void postProcessFlowFile(ProcessContext context, ProcessSession session, FlowFile flowFile, AmazonWebServiceResult response) {
super.postProcessFlowFile(context, session, flowFile, response);
protected TextractClientBuilder createClientBuilder(final ProcessContext context) {
return TextractClient.builder();
}
@Override
protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTextractClient) AmazonTextractClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
}
@Override
protected AmazonWebServiceResult sendRequest(AmazonWebServiceRequest request, ProcessContext context, FlowFile flowFile) {
TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
protected TextractResponse sendRequest(final TextractRequest request, final ProcessContext context, final FlowFile flowFile) {
TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return switch (textractType) {
case DOCUMENT_ANALYSIS -> getClient(context).startDocumentAnalysis((StartDocumentAnalysisRequest) request);
case DOCUMENT_TEXT_DETECTION -> getClient(context).startDocumentTextDetection((StartDocumentTextDetectionRequest) request);
@ -110,22 +89,28 @@ public class StartAwsTextractJob extends AwsMachineLearningJobStarter<AmazonText
}
@Override
protected Class<? extends AmazonWebServiceRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
protected Class<? extends TextractRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
final TextractType typeOfTextract = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return switch (typeOfTextract) {
case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.class;
case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.class;
case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.class;
case DOCUMENT_ANALYSIS -> StartDocumentAnalysisRequest.serializableBuilderClass();
case DOCUMENT_TEXT_DETECTION -> StartDocumentTextDetectionRequest.serializableBuilderClass();
case EXPENSE_ANALYSIS -> StartExpenseAnalysisRequest.serializableBuilderClass();
};
}
@Override
protected String getAwsTaskId(ProcessContext context, AmazonWebServiceResult amazonWebServiceResult, FlowFile flowFile) {
final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).evaluateAttributeExpressions(flowFile).getValue());
protected String getAwsTaskId(final ProcessContext context, final TextractResponse textractResponse, final FlowFile flowFile) {
final TextractType textractType = TextractType.fromString(context.getProperty(TEXTRACT_TYPE.getName()).getValue());
return switch (textractType) {
case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResult) amazonWebServiceResult).getJobId();
case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResult) amazonWebServiceResult).getJobId();
case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResult) amazonWebServiceResult).getJobId();
case DOCUMENT_ANALYSIS -> ((StartDocumentAnalysisResponse) textractResponse).jobId();
case DOCUMENT_TEXT_DETECTION -> ((StartDocumentTextDetectionResponse) textractResponse).jobId();
case EXPENSE_ANALYSIS -> ((StartExpenseAnalysisResponse) textractResponse).jobId();
};
}
@Override
protected FlowFile postProcessFlowFile(final ProcessContext context, final ProcessSession session, FlowFile flowFile, final TextractResponse response) {
flowFile = super.postProcessFlowFile(context, session, flowFile, response);
return session.putAttribute(flowFile, TEXTRACT_TYPE_ATTRIBUTE, context.getProperty(TEXTRACT_TYPE).getValue());
}
}

View File

@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.transcribe;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.model.ThrottlingException;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@ -35,7 +26,13 @@ import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.LimitExceededException;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
@CapabilityDescription("Retrieves the current status of an AWS Transcribe job.")
@ -43,55 +40,50 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
@WritesAttributes({
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
public class GetAwsTranscribeJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranscribeClient> {
public class GetAwsTranscribeJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranscribeClient, TranscribeClientBuilder> {
@Override
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withEndpointConfiguration(endpointConfiguration)
.withClientConfiguration(config)
.build();
protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
return TranscribeClient.builder();
}
@Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
try {
GetTranscriptionJobResult job = getJob(context, flowFile);
TranscriptionJobStatus jobStatus = TranscriptionJobStatus.fromValue(job.getTranscriptionJob().getTranscriptionJobStatus());
final GetTranscriptionJobResponse job = getJob(context, flowFile);
final TranscriptionJobStatus status = job.transcriptionJob().transcriptionJobStatus();
if (TranscriptionJobStatus.COMPLETED == jobStatus) {
writeToFlowFile(session, flowFile, job);
session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.getTranscriptionJob().getTranscript().getTranscriptFileUri());
if (TranscriptionJobStatus.COMPLETED == status) {
flowFile = writeToFlowFile(session, flowFile, job);
flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.transcriptionJob().transcript().transcriptFileUri());
session.transfer(flowFile, REL_SUCCESS);
} else if (TranscriptionJobStatus.IN_PROGRESS == jobStatus) {
} else if (TranscriptionJobStatus.IN_PROGRESS == status || TranscriptionJobStatus.QUEUED == status) {
session.transfer(flowFile, REL_RUNNING);
} else if (TranscriptionJobStatus.FAILED == jobStatus) {
final String failureReason = job.getTranscriptionJob().getFailureReason();
session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
} else if (TranscriptionJobStatus.FAILED == status) {
final String failureReason = job.transcriptionJob().failureReason();
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
session.transfer(flowFile, REL_FAILURE);
getLogger().error("Transcribe Task Failed {} Reason [{}]", flowFile, failureReason);
} else {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, "Unrecognized job status");
throw new IllegalStateException("Unrecognized job status");
}
} catch (ThrottlingException e) {
} catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
return;
} catch (Exception e) {
} catch (final Exception e) {
getLogger().warn("Failed to get Transcribe Job status", e);
session.transfer(flowFile, REL_FAILURE);
return;
}
}
private GetTranscriptionJobResult getJob(ProcessContext context, FlowFile flowFile) {
String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
GetTranscriptionJobRequest request = new GetTranscriptionJobRequest().withTranscriptionJobName(taskId);
private GetTranscriptionJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
final GetTranscriptionJobRequest request = GetTranscriptionJobRequest.builder().transcriptionJobName(taskId).build();
return getClient(context).getTranscriptionJob(request);
}
}

View File

@ -17,48 +17,45 @@
package org.apache.nifi.processors.aws.ml.transcribe;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.StartTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.StartTranscriptionJobResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.TranscribeClientBuilder;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Transcribe"})
@CapabilityDescription("Trigger a AWS Transcribe job. It should be followed by GetAwsTranscribeStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranscribeJobStatus")
})
@SeeAlso({GetAwsTranscribeJobStatus.class})
public class StartAwsTranscribeJob extends AwsMachineLearningJobStarter<AmazonTranscribeClient, StartTranscriptionJobRequest, StartTranscriptionJobResult> {
public class StartAwsTranscribeJob extends AbstractAwsMachineLearningJobStarter<
StartTranscriptionJobRequest, StartTranscriptionJobRequest.Builder, StartTranscriptionJobResponse, TranscribeClient, TranscribeClientBuilder> {
@Override
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTranscribeClient) AmazonTranscribeClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.withCredentials(credentialsProvider)
.build();
protected TranscribeClientBuilder createClientBuilder(final ProcessContext context) {
return TranscribeClient.builder();
}
@Override
protected StartTranscriptionJobResult sendRequest(StartTranscriptionJobRequest request, ProcessContext context, FlowFile flowFile) {
protected StartTranscriptionJobResponse sendRequest(final StartTranscriptionJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTranscriptionJob(request);
}
@Override
protected Class<? extends StartTranscriptionJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
return StartTranscriptionJobRequest.class;
protected Class<? extends StartTranscriptionJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartTranscriptionJobRequest.serializableBuilderClass();
}
@Override
protected String getAwsTaskId(ProcessContext context, StartTranscriptionJobResult startTranscriptionJobResult, FlowFile flowFile) {
return startTranscriptionJobResult.getTranscriptionJob().getTranscriptionJobName();
protected String getAwsTaskId(final ProcessContext context, final StartTranscriptionJobResponse startTranscriptionJobResponse, final FlowFile flowFile) {
return startTranscriptionJobResponse.transcriptionJob().transcriptionJobName();
}
}

View File

@ -17,15 +17,6 @@
package org.apache.nifi.processors.aws.ml.translate;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.model.ThrottlingException;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
import com.amazonaws.services.translate.model.JobStatus;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
@ -34,8 +25,15 @@ import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processor.ProcessSession;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.TranslateClientBuilder;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
import software.amazon.awssdk.services.translate.model.JobStatus;
import software.amazon.awssdk.services.translate.model.LimitExceededException;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Retrieves the current status of an AWS Translate job.")
@ -43,54 +41,66 @@ import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor;
@WritesAttributes({
@WritesAttribute(attribute = "outputLocation", description = "S3 path-style output location of the result.")
})
public class GetAwsTranslateJobStatus extends AwsMachineLearningJobStatusProcessor<AmazonTranslateClient> {
public class GetAwsTranslateJobStatus extends AbstractAwsMachineLearningJobStatusProcessor<TranslateClient, TranslateClientBuilder> {
@Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTranslateClient) AmazonTranslateClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
return TranslateClient.builder();
}
@Override
public void onTrigger(ProcessContext context, ProcessSession session) throws ProcessException {
public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
FlowFile flowFile = session.get();
if (flowFile == null) {
return;
}
String awsTaskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
try {
DescribeTextTranslationJobResult describeTextTranslationJobResult = getStatusString(context, awsTaskId);
JobStatus status = JobStatus.fromValue(describeTextTranslationJobResult.getTextTranslationJobProperties().getJobStatus());
final DescribeTextTranslationJobResponse job = getJob(context, flowFile);
final JobStatus status = job.textTranslationJobProperties().jobStatus();
if (status == JobStatus.IN_PROGRESS || status == JobStatus.SUBMITTED) {
writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
session.penalize(flowFile);
session.transfer(flowFile, REL_RUNNING);
} else if (status == JobStatus.COMPLETED) {
session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, describeTextTranslationJobResult.getTextTranslationJobProperties().getOutputDataConfig().getS3Uri());
writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
session.transfer(flowFile, REL_SUCCESS);
} else if (status == JobStatus.FAILED || status == JobStatus.COMPLETED_WITH_ERROR) {
writeToFlowFile(session, flowFile, describeTextTranslationJobResult);
session.transfer(flowFile, REL_FAILURE);
flowFile = writeToFlowFile(session, flowFile, job);
final Relationship transferRelationship;
String failureReason = null;
switch (status) {
case IN_PROGRESS:
case SUBMITTED:
case STOP_REQUESTED:
flowFile = session.penalize(flowFile);
transferRelationship = REL_RUNNING;
break;
case COMPLETED:
flowFile = session.putAttribute(flowFile, AWS_TASK_OUTPUT_LOCATION, job.textTranslationJobProperties().outputDataConfig().s3Uri());
transferRelationship = REL_SUCCESS;
break;
case FAILED:
case COMPLETED_WITH_ERROR:
failureReason = job.textTranslationJobProperties().message();
transferRelationship = REL_FAILURE;
break;
case STOPPED:
failureReason = String.format("Job [%s] is stopped", job.textTranslationJobProperties().jobId());
transferRelationship = REL_FAILURE;
break;
default:
failureReason = "Unknown Job Status";
transferRelationship = REL_FAILURE;
}
} catch (ThrottlingException e) {
if (failureReason != null) {
flowFile = session.putAttribute(flowFile, FAILURE_REASON_ATTRIBUTE, failureReason);
}
session.transfer(flowFile, transferRelationship);
} catch (final LimitExceededException e) {
getLogger().info("Request Rate Limit exceeded", e);
session.transfer(flowFile, REL_THROTTLED);
} catch (Exception e) {
getLogger().warn("Failed to get Polly Job status", e);
} catch (final Exception e) {
getLogger().warn("Failed to get Translate Job status", e);
session.transfer(flowFile, REL_FAILURE);
}
}
private DescribeTextTranslationJobResult getStatusString(ProcessContext context, String awsTaskId) {
DescribeTextTranslationJobRequest request = new DescribeTextTranslationJobRequest().withJobId(awsTaskId);
DescribeTextTranslationJobResult translationJobsResult = getClient(context).describeTextTranslationJob(request);
return translationJobsResult;
private DescribeTextTranslationJobResponse getJob(final ProcessContext context, final FlowFile flowFile) {
final String taskId = context.getProperty(TASK_ID).evaluateAttributeExpressions(flowFile).getValue();
final DescribeTextTranslationJobRequest request = DescribeTextTranslationJobRequest.builder().jobId(taskId).build();
return getClient(context).describeTextTranslationJob(request);
}
}

View File

@ -17,47 +17,44 @@
package org.apache.nifi.processors.aws.ml.translate;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.StartTextTranslationJobRequest;
import com.amazonaws.services.translate.model.StartTextTranslationJobResult;
import org.apache.nifi.annotation.behavior.WritesAttribute;
import org.apache.nifi.annotation.behavior.WritesAttributes;
import org.apache.nifi.annotation.documentation.CapabilityDescription;
import org.apache.nifi.annotation.documentation.SeeAlso;
import org.apache.nifi.annotation.documentation.Tags;
import org.apache.nifi.flowfile.FlowFile;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStarter;
import org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStarter;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.TranslateClientBuilder;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;
@Tags({"Amazon", "AWS", "ML", "Machine Learning", "Translate"})
@CapabilityDescription("Trigger a AWS Translate job. It should be followed by GetAwsTranslateJobStatus processor in order to monitor job status.")
@WritesAttributes({
@WritesAttribute(attribute = "awsTaskId", description = "The task ID that can be used to poll for Job completion in GetAwsTranslateJobStatus")
})
@SeeAlso({GetAwsTranslateJobStatus.class})
public class StartAwsTranslateJob extends AwsMachineLearningJobStarter<AmazonTranslateClient, StartTextTranslationJobRequest, StartTextTranslationJobResult> {
public class StartAwsTranslateJob extends AbstractAwsMachineLearningJobStarter<
StartTextTranslationJobRequest, StartTextTranslationJobRequest.Builder, StartTextTranslationJobResponse, TranslateClient, TranslateClientBuilder> {
@Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
return (AmazonTranslateClient) AmazonTranslateClient.builder()
.withRegion(context.getProperty(REGION).getValue())
.withCredentials(credentialsProvider)
.withClientConfiguration(config)
.withEndpointConfiguration(endpointConfiguration)
.build();
protected TranslateClientBuilder createClientBuilder(final ProcessContext context) {
return TranslateClient.builder();
}
@Override
protected StartTextTranslationJobResult sendRequest(StartTextTranslationJobRequest request, ProcessContext context, FlowFile flowFile) {
protected StartTextTranslationJobResponse sendRequest(final StartTextTranslationJobRequest request, final ProcessContext context, final FlowFile flowFile) {
return getClient(context).startTextTranslationJob(request);
}
@Override
protected Class<StartTextTranslationJobRequest> getAwsRequestClass(ProcessContext context, FlowFile flowFile) {
return StartTextTranslationJobRequest.class;
protected Class<? extends StartTextTranslationJobRequest.Builder> getAwsRequestBuilderClass(final ProcessContext context, final FlowFile flowFile) {
return StartTextTranslationJobRequest.serializableBuilderClass();
}
protected String getAwsTaskId(ProcessContext context, StartTextTranslationJobResult startTextTranslationJobResult, FlowFile flowFile) {
return startTextTranslationJobResult.getJobId();
protected String getAwsTaskId(final ProcessContext context, final StartTextTranslationJobResponse startTextTranslationJobResponse, final FlowFile flowFile) {
return startTextTranslationJobResponse.jobId();
}
}

View File

@ -17,18 +17,10 @@
package org.apache.nifi.processors.aws.ml.polly;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.polly.AmazonPollyClient;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskRequest;
import com.amazonaws.services.polly.model.GetSpeechSynthesisTaskResult;
import com.amazonaws.services.polly.model.SynthesisTask;
import com.amazonaws.services.polly.model.TaskStatus;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@ -38,16 +30,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.GetSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.SynthesisTask;
import software.amazon.awssdk.services.polly.model.TaskStatus;
import java.util.Collections;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ -57,73 +53,84 @@ public class GetAwsPollyStatusTest {
private static final String PLACEHOLDER_CONTENT = "content";
private TestRunner runner;
@Mock
private AmazonPollyClient mockPollyClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider;
private PollyClient mockPollyClient;
private GetAwsPollyJobStatus processor;
@Captor
private ArgumentCaptor<GetSpeechSynthesisTaskRequest> requestCaptor;
private TestRunner createRunner(final GetAwsPollyJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach
public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider");
final GetAwsPollyJobStatus mockGetAwsPollyStatus = new GetAwsPollyJobStatus() {
processor = new GetAwsPollyJobStatus() {
@Override
protected AmazonPollyClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
public PollyClient getClient(final ProcessContext context) {
return mockPollyClient;
}
};
runner = TestRunners.newTestRunner(mockGetAwsPollyStatus);
runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
runner = createRunner(processor);
}
@Test
public void testPollyTaskInProgress() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
.withTaskStatus(TaskStatus.InProgress);
taskResult.setSynthesisTask(task);
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).taskStatus(TaskStatus.IN_PROGRESS).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
}
@Test
public void testPollyTaskCompleted() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
.withTaskStatus(TaskStatus.Completed)
.withOutputUri("outputLocationPath");
taskResult.setSynthesisTask(task);
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
final String uri = "https://s3.us-west2.amazonaws.com/bucket/object";
final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder()
.taskId(TEST_TASK_ID)
.taskStatus(TaskStatus.COMPLETED)
.outputUri(uri).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
runner.assertAllFlowFilesContainAttribute(REL_SUCCESS, AWS_TASK_OUTPUT_LOCATION);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
}
assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
assertEquals(uri, flowFile.getAttribute(GetAwsPollyJobStatus.AWS_TASK_OUTPUT_LOCATION));
assertEquals("bucket", flowFile.getAttribute("PollyS3OutputBucket"));
assertEquals("object", flowFile.getAttribute("filename"));
}
@Test
public void testPollyTaskFailed() {
GetSpeechSynthesisTaskResult taskResult = new GetSpeechSynthesisTaskResult();
SynthesisTask task = new SynthesisTask().withTaskId(TEST_TASK_ID)
.withTaskStatus(TaskStatus.Failed)
.withTaskStatusReason("reasonOfFailure");
taskResult.setSynthesisTask(task);
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(taskResult);
final String failureReason = "reasonOfFailure";
final GetSpeechSynthesisTaskResponse response = GetSpeechSynthesisTaskResponse.builder()
.synthesisTask(SynthesisTask.builder()
.taskId(TEST_TASK_ID)
.taskStatus(TaskStatus.FAILED)
.taskStatusReason(failureReason).build())
.build();
when(mockPollyClient.getSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
runner.enqueue(PLACEHOLDER_CONTENT, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTaskId());
assertEquals(TEST_TASK_ID, requestCaptor.getValue().taskId());
final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
assertEquals(failureReason, flowFile.getAttribute(GetAwsPollyJobStatus.FAILURE_REASON_ATTRIBUTE));
}
}

View File

@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.polly;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.polly.PollyClient;
import software.amazon.awssdk.services.polly.model.Engine;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskRequest;
import software.amazon.awssdk.services.polly.model.StartSpeechSynthesisTaskResponse;
import software.amazon.awssdk.services.polly.model.SynthesisTask;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StartAwsPollyJobTest {
    private static final String TEST_TASK_ID = "testTaskId";
    private static final String TEST_TEXT = "Text";

    // ObjectMapper is thread-safe and expensive to build: share one static final instance.
    // Case-insensitive matching lets JSON fields map onto the SDK builder properties.
    private static final ObjectMapper OBJECT_MAPPER = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();

    private TestRunner runner;

    @Mock
    private PollyClient mockPollyClient;

    private StartAwsPollyJob processor;

    @Captor
    private ArgumentCaptor<StartSpeechSynthesisTaskRequest> requestCaptor;

    /** Builds a runner with dummy access-key credentials so the processor passes validation. */
    private TestRunner createRunner(final StartAwsPollyJob processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        // Override getClient so the processor talks to the mocked Polly client, not AWS.
        processor = new StartAwsPollyJob() {
            @Override
            public PollyClient getClient(final ProcessContext context) {
                return mockPollyClient;
            }
        };
        runner = createRunner(processor);
    }

    @Test
    public void testSuccessfulFlowfileContent() throws JsonProcessingException {
        // Request JSON arrives as the FlowFile content.
        stubSuccessfulSynthesis();

        runner.enqueue(serialize(newRequest()));
        runner.run();

        assertSuccessfulSynthesis();
    }

    @Test
    public void testSuccessfulAttribute() throws JsonProcessingException {
        // Same success path, but the request JSON arrives via a FlowFile attribute
        // referenced through the JSON_PAYLOAD expression-language property.
        stubSuccessfulSynthesis();

        runner.setProperty(StartAwsPollyJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", serialize(newRequest()));
        runner.enqueue("", attributes);
        runner.run();

        assertSuccessfulSynthesis();
    }

    @Test
    public void testInvalidJson() {
        // Content that cannot be parsed into a StartSpeechSynthesisTaskRequest routes to failure.
        runner.enqueue("invalid");
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    @Test
    public void testServiceFailure() throws JsonProcessingException {
        // A service-side exception from Polly routes the FlowFile to failure.
        when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture()))
                .thenThrow(AwsServiceException.builder().message("message").build());

        runner.enqueue(serialize(newRequest()));
        runner.run();
        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    /** Builds the canonical synthesis request shared by the success-path tests. */
    private StartSpeechSynthesisTaskRequest newRequest() {
        return StartSpeechSynthesisTaskRequest.builder()
                .engine(Engine.NEURAL)
                .text(TEST_TEXT)
                .build();
    }

    /** Stubs the mocked client to return a synthesis task carrying TEST_TASK_ID. */
    private void stubSuccessfulSynthesis() {
        final StartSpeechSynthesisTaskResponse response = StartSpeechSynthesisTaskResponse.builder()
                .synthesisTask(SynthesisTask.builder().taskId(TEST_TASK_ID).build())
                .build();
        when(mockPollyClient.startSpeechSynthesisTask(requestCaptor.capture())).thenReturn(response);
    }

    /** Asserts success/original routing and that the serialized response round-trips the task ID. */
    private void assertSuccessfulSynthesis() throws JsonProcessingException {
        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartSpeechSynthesisTaskResponse parsedResponse = deserialize(responseData);
        assertEquals(TEST_TEXT, requestCaptor.getValue().text());
        assertEquals(TEST_TASK_ID, parsedResponse.synthesisTask().taskId());
    }

    private StartSpeechSynthesisTaskResponse deserialize(final String responseData) throws JsonProcessingException {
        return OBJECT_MAPPER.readValue(responseData, StartSpeechSynthesisTaskResponse.serializableBuilderClass()).build();
    }

    private String serialize(final StartSpeechSynthesisTaskRequest request) throws JsonProcessingException {
        return OBJECT_MAPPER.writeValueAsString(request.toBuilder());
    }
}

View File

@ -17,17 +17,10 @@
package org.apache.nifi.processors.aws.ml.textract;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.textract.AmazonTextractClient;
import com.amazonaws.services.textract.model.GetDocumentAnalysisRequest;
import com.amazonaws.services.textract.model.GetDocumentAnalysisResult;
import com.amazonaws.services.textract.model.JobStatus;
import com.google.common.collect.ImmutableMap;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
@ -38,13 +31,20 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.GetDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.GetExpenseAnalysisResponse;
import software.amazon.awssdk.services.textract.model.JobStatus;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.textract.StartAwsTextractJob.TEXTRACT_TYPE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_THROTTLED;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ -53,64 +53,144 @@ public class GetAwsTextractJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
private AmazonTextractClient mockTextractClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider;
private TextractClient mockTextractClient;
private GetAwsTextractJobStatus processor;
@Captor
private ArgumentCaptor<GetDocumentAnalysisRequest> requestCaptor;
private ArgumentCaptor<GetDocumentAnalysisRequest> documentAnalysisCaptor;
@Captor
private ArgumentCaptor<GetExpenseAnalysisRequest> expenseAnalysisRequestCaptor;
@Captor
private ArgumentCaptor<GetDocumentTextDetectionRequest> documentTextDetectionCaptor;
private TestRunner createRunner(final GetAwsTextractJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach
public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn("awsCredentialProvider");
final GetAwsTextractJobStatus awsTextractJobStatusGetter = new GetAwsTextractJobStatus() {
processor = new GetAwsTextractJobStatus() {
@Override
protected AmazonTextractClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
public TextractClient getClient(final ProcessContext context) {
return mockTextractClient;
}
};
runner = TestRunners.newTestRunner(awsTextractJobStatusGetter);
runner.addControllerService("awsCredentialProvider", mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, "awsCredentialProvider");
runner = createRunner(processor);
}
@Test
public void testTextractDocAnalysisTaskInProgress() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
.withJobStatus(JobStatus.IN_PROGRESS);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
testTextractDocAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
}
@Test
public void testTextractDocAnalysisTaskComplete() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
.withJobStatus(JobStatus.SUCCEEDED);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.name()));
runner.run();
runner.assertAllFlowFilesTransferred(REL_SUCCESS);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
testTextractDocAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
}
@Test
public void testTextractDocAnalysisTaskFailed() {
GetDocumentAnalysisResult taskResult = new GetDocumentAnalysisResult()
.withJobStatus(JobStatus.FAILED);
when(mockTextractClient.getDocumentAnalysis(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID,
TEXTRACT_TYPE.getName(), TextractType.DOCUMENT_ANALYSIS.type));
testTextractDocAnalysis(JobStatus.FAILED, REL_FAILURE);
}
@Test
public void testTextractDocAnalysisTaskPartialSuccess() {
testTextractDocAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}
@Test
public void testTextractDocAnalysisTaskUnkownStatus() {
testTextractDocAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}
private void testTextractDocAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
final GetDocumentAnalysisResponse response = GetDocumentAnalysisResponse.builder()
.jobStatus(jobStatus).build();
when(mockTextractClient.getDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(response);
runner.enqueue("content", ImmutableMap.of(
TASK_ID.getName(), TEST_TASK_ID,
StartAwsTextractJob.TEXTRACT_TYPE_ATTRIBUTE, TextractType.DOCUMENT_ANALYSIS.getType()));
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, documentAnalysisCaptor.getValue().jobId());
}
@Test
public void testTextractExpenseAnalysisTaskInProgress() {
    // IN_PROGRESS routes to the running relationship so the FlowFile can be re-polled.
    testTextractExpenseAnalysis(JobStatus.IN_PROGRESS, REL_RUNNING);
}
@Test
public void testTextractExpenseAnalysisTaskComplete() {
    // SUCCEEDED routes to the success relationship.
    testTextractExpenseAnalysis(JobStatus.SUCCEEDED, REL_SUCCESS);
}
@Test
public void testTextractExpenseAnalysisTaskFailed() {
    // FAILED routes to the failure relationship.
    testTextractExpenseAnalysis(JobStatus.FAILED, REL_FAILURE);
}
@Test
public void testTextractExpenseAnalysisTaskPartialSuccess() {
    // PARTIAL_SUCCESS routes to the throttled relationship.
    testTextractExpenseAnalysis(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}
@Test
// NOTE(review): "Unkown" is a typo for "Unknown" in the method name; left as-is to avoid renaming a test.
public void testTextractExpenseAnalysisTaskUnkownStatus() {
    // A status value the SDK cannot map is treated as a failure.
    testTextractExpenseAnalysis(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}
/**
 * Runs the processor in expense-analysis mode with the mocked client stubbed to report
 * {@code jobStatus}, then verifies the routing and the job ID sent in the poll request.
 */
private void testTextractExpenseAnalysis(final JobStatus jobStatus, final Relationship expectedRelationship) {
    runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.EXPENSE_ANALYSIS.getType());

    final GetExpenseAnalysisResponse stubbedResponse = GetExpenseAnalysisResponse.builder()
            .jobStatus(jobStatus)
            .build();
    when(mockTextractClient.getExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(stubbedResponse);

    runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
    runner.run();

    runner.assertAllFlowFilesTransferred(expectedRelationship);
    assertEquals(TEST_TASK_ID, expenseAnalysisRequestCaptor.getValue().jobId());
}
@Test
public void testTextractDocumentTextDetectionTaskInProgress() {
    // IN_PROGRESS routes to the running relationship so the FlowFile can be re-polled.
    testTextractDocumentTextDetection(JobStatus.IN_PROGRESS, REL_RUNNING);
}
@Test
public void testTextractDocumentTextDetectionTaskComplete() {
    // SUCCEEDED routes to the success relationship.
    testTextractDocumentTextDetection(JobStatus.SUCCEEDED, REL_SUCCESS);
}
@Test
public void testTextractDocumentTextDetectionTaskFailed() {
    // FAILED routes to the failure relationship.
    testTextractDocumentTextDetection(JobStatus.FAILED, REL_FAILURE);
}
@Test
public void testTextractDocumentTextDetectionTaskPartialSuccess() {
    // PARTIAL_SUCCESS routes to the throttled relationship.
    testTextractDocumentTextDetection(JobStatus.PARTIAL_SUCCESS, REL_THROTTLED);
}
@Test
// NOTE(review): "Unkown" is a typo for "Unknown" in the method name; left as-is to avoid renaming a test.
public void testTextractDocumentTextDetectionTaskUnkownStatus() {
    // A status value the SDK cannot map is treated as a failure.
    testTextractDocumentTextDetection(JobStatus.UNKNOWN_TO_SDK_VERSION, REL_FAILURE);
}
/**
 * Runs the processor in document-text-detection mode with the mocked client stubbed to report
 * {@code jobStatus}, then verifies the routing and the job ID sent in the poll request.
 */
private void testTextractDocumentTextDetection(final JobStatus jobStatus, final Relationship expectedRelationship) {
    runner.setProperty(GetAwsTextractJobStatus.TEXTRACT_TYPE, TextractType.DOCUMENT_TEXT_DETECTION.getType());

    final GetDocumentTextDetectionResponse stubbedResponse = GetDocumentTextDetectionResponse.builder()
            .jobStatus(jobStatus)
            .build();
    when(mockTextractClient.getDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(stubbedResponse);

    runner.enqueue("content", ImmutableMap.of(TASK_ID.getName(), TEST_TASK_ID));
    runner.run();

    runner.assertAllFlowFilesTransferred(expectedRelationship);
    assertEquals(TEST_TASK_ID, documentTextDetectionCaptor.getValue().jobId());
}
}

View File

@ -0,0 +1,342 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.textract;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.textract.TextractClient;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentAnalysisResponse;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionRequest;
import software.amazon.awssdk.services.textract.model.StartDocumentTextDetectionResponse;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisRequest;
import software.amazon.awssdk.services.textract.model.StartExpenseAnalysisResponse;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_ANALYSIS;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.DOCUMENT_TEXT_DETECTION;
import static org.apache.nifi.processors.aws.ml.textract.TextractType.EXPENSE_ANALYSIS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StartAwsTextractJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
private TextractClient mockTextractClient;
private StartAwsTextractJob processor;
private ObjectMapper objectMapper = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.build();
@Captor
private ArgumentCaptor<StartDocumentAnalysisRequest> documentAnalysisCaptor;
@Captor
private ArgumentCaptor<StartExpenseAnalysisRequest> expenseAnalysisRequestCaptor;
@Captor
private ArgumentCaptor<StartDocumentTextDetectionRequest> documentTextDetectionCaptor;
/** Builds a runner with dummy access-key credentials so the processor passes validation. */
private TestRunner createRunner(final StartAwsTextractJob processor) {
    final TestRunner testRunner = TestRunners.newTestRunner(processor);
    AuthUtils.enableAccessKey(testRunner, "abcd", "defg");
    return testRunner;
}
@BeforeEach
public void setUp() throws InitializationException {
    // Override getClient so the processor talks to the mocked Textract client, not AWS.
    processor = new StartAwsTextractJob() {
        @Override
        public TextractClient getClient(final ProcessContext context) {
            return mockTextractClient;
        }
    };
    runner = createRunner(processor);
}
@Test
public void testSuccessfulDocumentAnalysisFlowfileContent() throws JsonProcessingException {
    // Stub a successful document-analysis start that reports the expected job ID.
    final StartDocumentAnalysisResponse stubbedResponse = StartDocumentAnalysisResponse.builder()
            .jobId(TEST_TASK_ID)
            .build();
    when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(stubbedResponse);

    // The request JSON is supplied as the FlowFile content.
    final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
    runner.enqueue(serialize(request));
    runner.run();

    runner.assertTransferCount(REL_SUCCESS, 1);
    runner.assertTransferCount(REL_ORIGINAL, 1);

    final String content = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
    final StartDocumentAnalysisResponse parsed = deserializeDARequest(content);
    assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
    assertEquals(TEST_TASK_ID, parsed.jobId());
}
@Test
public void testSuccessfulDocumentAnalysisAttribute() throws JsonProcessingException {
    // Same success path as the content-based test, but the request JSON arrives via a
    // FlowFile attribute referenced through the JSON_PAYLOAD expression-language property.
    final StartDocumentAnalysisResponse stubbedResponse = StartDocumentAnalysisResponse.builder()
            .jobId(TEST_TASK_ID)
            .build();
    when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenReturn(stubbedResponse);

    final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
    runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");

    final Map<String, String> attributes = new HashMap<>();
    attributes.put("json.payload", serialize(request));
    runner.enqueue("", attributes);
    runner.run();

    runner.assertTransferCount(REL_SUCCESS, 1);
    runner.assertTransferCount(REL_ORIGINAL, 1);

    final String content = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
    final StartDocumentAnalysisResponse parsed = deserializeDARequest(content);
    assertEquals("Tag", documentAnalysisCaptor.getValue().jobTag());
    assertEquals(TEST_TASK_ID, parsed.jobId());
}
@Test
public void testInvalidDocumentAnalysisJson() {
    // Select Document Analysis explicitly, matching the other invalid-JSON tests
    // (testInvalidExpenseAnalysisJson, testInvalidDocumentTextDetectionJson) rather
    // than relying on the processor's default Textract type.
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
    // Content that cannot be parsed as a request must be routed to failure.
    runner.enqueue("invalid");
    runner.run();
    runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
@Test
public void testDocumentAnalysisServiceFailure() throws JsonProcessingException {
    final StartDocumentAnalysisRequest request = StartDocumentAnalysisRequest.builder()
            .jobTag("Tag")
            .build();
    // Any AWS service exception during submission must route the FlowFile to failure.
    when(mockTextractClient.startDocumentAnalysis(documentAnalysisCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
    // Select Document Analysis explicitly, matching the other service-failure tests
    // instead of depending on the processor's default Textract type.
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_ANALYSIS.getType());
    runner.enqueue(serialize(request));
    runner.run();
    runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
@Test
public void testSuccessfulExpenseAnalysisFlowfileContent() throws JsonProcessingException {
    // Stub the client so the captured request can be inspected after the run.
    final StartExpenseAnalysisResponse stubbedResponse = StartExpenseAnalysisResponse.builder()
            .jobId(TEST_TASK_ID)
            .build();
    when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(stubbedResponse);

    final StartExpenseAnalysisRequest expenseRequest = StartExpenseAnalysisRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
    // The request JSON is provided as the FlowFile content.
    runner.enqueue(serialize(expenseRequest));
    runner.run();

    runner.assertTransferCount(REL_SUCCESS, 1);
    runner.assertTransferCount(REL_ORIGINAL, 1);
    final String successContent = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
    assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
    assertEquals(TEST_TASK_ID, deserializeEARequest(successContent).jobId());
}
@Test
public void testSuccessfulExpenseAnalysisAttribute() throws JsonProcessingException {
    // Stub the client so the captured request can be inspected after the run.
    final StartExpenseAnalysisResponse stubbedResponse = StartExpenseAnalysisResponse.builder()
            .jobId(TEST_TASK_ID)
            .build();
    when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture())).thenReturn(stubbedResponse);

    final StartExpenseAnalysisRequest expenseRequest = StartExpenseAnalysisRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
    runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
    // Supply the request JSON through a FlowFile attribute instead of the content.
    runner.enqueue("", Map.of("json.payload", serialize(expenseRequest)));
    runner.run();

    runner.assertTransferCount(REL_SUCCESS, 1);
    runner.assertTransferCount(REL_ORIGINAL, 1);
    final String successContent = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
    assertEquals("Tag", expenseAnalysisRequestCaptor.getValue().jobTag());
    assertEquals(TEST_TASK_ID, deserializeEARequest(successContent).jobId());
}
@Test
public void testInvalidExpenseAnalysisJson() {
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
    // Unparseable request content must be routed to the failure relationship.
    runner.enqueue("invalid");
    runner.run();
    runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
@Test
public void testExpenseAnalysisServiceFailure() throws JsonProcessingException {
    // A service-side exception during submission routes the FlowFile to failure.
    when(mockTextractClient.startExpenseAnalysis(expenseAnalysisRequestCaptor.capture()))
            .thenThrow(AwsServiceException.builder().message("message").build());
    final StartExpenseAnalysisRequest expenseRequest = StartExpenseAnalysisRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, EXPENSE_ANALYSIS.getType());
    runner.enqueue(serialize(expenseRequest));
    runner.run();
    runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
@Test
public void testSuccessfulDocumentTextDetectionFlowfileContent() throws JsonProcessingException {
    // Stub the client so the captured request can be inspected after the run.
    final StartDocumentTextDetectionResponse stubbedResponse = StartDocumentTextDetectionResponse.builder()
            .jobId(TEST_TASK_ID)
            .build();
    when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(stubbedResponse);

    final StartDocumentTextDetectionRequest detectionRequest = StartDocumentTextDetectionRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
    // The request JSON is provided as the FlowFile content.
    runner.enqueue(serialize(detectionRequest));
    runner.run();

    runner.assertTransferCount(REL_SUCCESS, 1);
    runner.assertTransferCount(REL_ORIGINAL, 1);
    final String successContent = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
    assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
    assertEquals(TEST_TASK_ID, deserializeDTDRequest(successContent).jobId());
}
@Test
public void testSuccessfulDocumentTextDetectionAttribute() throws JsonProcessingException {
    // Stub the client so the captured request can be inspected after the run.
    final StartDocumentTextDetectionResponse stubbedResponse = StartDocumentTextDetectionResponse.builder()
            .jobId(TEST_TASK_ID)
            .build();
    when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture())).thenReturn(stubbedResponse);

    final StartDocumentTextDetectionRequest detectionRequest = StartDocumentTextDetectionRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
    runner.setProperty(StartAwsTextractJob.JSON_PAYLOAD, "${json.payload}");
    // Supply the request JSON through a FlowFile attribute instead of the content.
    runner.enqueue("", Map.of("json.payload", serialize(detectionRequest)));
    runner.run();

    runner.assertTransferCount(REL_SUCCESS, 1);
    runner.assertTransferCount(REL_ORIGINAL, 1);
    final String successContent = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
    assertEquals("Tag", documentTextDetectionCaptor.getValue().jobTag());
    assertEquals(TEST_TASK_ID, deserializeDTDRequest(successContent).jobId());
}
@Test
public void testInvalidDocumentTextDetectionJson() {
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
    // Unparseable request content must be routed to the failure relationship.
    runner.enqueue("invalid");
    runner.run();
    runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
@Test
public void testDocumentTextDetectionServiceFailure() throws JsonProcessingException {
    // A service-side exception during submission routes the FlowFile to failure.
    when(mockTextractClient.startDocumentTextDetection(documentTextDetectionCaptor.capture()))
            .thenThrow(AwsServiceException.builder().message("message").build());
    final StartDocumentTextDetectionRequest detectionRequest = StartDocumentTextDetectionRequest.builder()
            .jobTag("Tag")
            .build();
    runner.setProperty(StartAwsTextractJob.TEXTRACT_TYPE, DOCUMENT_TEXT_DETECTION.getType());
    runner.enqueue(serialize(detectionRequest));
    runner.run();
    runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
// Deserializes a StartDocumentTextDetectionResponse via its serializable builder,
// since the SDK v2 response class itself has no JSON-friendly constructor.
private StartDocumentTextDetectionResponse deserializeDTDRequest(final String responseData) throws JsonProcessingException {
return objectMapper.readValue(responseData, StartDocumentTextDetectionResponse.serializableBuilderClass()).build();
}
// Deserializes a StartDocumentAnalysisResponse via its serializable builder,
// since the SDK v2 response class itself has no JSON-friendly constructor.
private StartDocumentAnalysisResponse deserializeDARequest(final String responseData) throws JsonProcessingException {
return objectMapper.readValue(responseData, StartDocumentAnalysisResponse.serializableBuilderClass()).build();
}
// Deserializes a StartExpenseAnalysisResponse via its serializable builder,
// since the SDK v2 response class itself has no JSON-friendly constructor.
private StartExpenseAnalysisResponse deserializeEARequest(final String responseData) throws JsonProcessingException {
return objectMapper.readValue(responseData, StartExpenseAnalysisResponse.serializableBuilderClass()).build();
}
// Serializes the request through its builder form, mirroring how the processor parses payloads.
private String serialize(final StartDocumentAnalysisRequest request) throws JsonProcessingException {
return objectMapper.writeValueAsString(request.toBuilder());
}
// Serializes the request through its builder form, mirroring how the processor parses payloads.
private String serialize(final StartExpenseAnalysisRequest request) throws JsonProcessingException {
return objectMapper.writeValueAsString(request.toBuilder());
}
// Serializes the request through its builder form, mirroring how the processor parses payloads.
private String serialize(final StartDocumentTextDetectionRequest request) throws JsonProcessingException {
return objectMapper.writeValueAsString(request.toBuilder());
}
}

View File

@ -17,19 +17,11 @@
package org.apache.nifi.processors.aws.ml.transcribe;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.transcribe.AmazonTranscribeClient;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobRequest;
import com.amazonaws.services.transcribe.model.GetTranscriptionJobResult;
import com.amazonaws.services.transcribe.model.Transcript;
import com.amazonaws.services.transcribe.model.TranscriptionJob;
import com.amazonaws.services.transcribe.model.TranscriptionJobStatus;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@ -39,93 +31,118 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.GetTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.Transcript;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJobStatus;
import java.util.Collections;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.FAILURE_REASON_ATTRIBUTE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class GetAwsTranscribeJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId";
private static final String AWS_CREDENTIAL_PROVIDER_NAME = "awsCredetialProvider";
private static final String TEST_TASK_ID = "testJobId";
private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
private static final String REASON_OF_FAILURE = "reasonOfFailure";
private static final String CONTENT_STRING = "content";
private TestRunner runner;
@Mock
private AmazonTranscribeClient mockTranscribeClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider;
private TranscribeClient mockTranscribeClient;
private GetAwsTranscribeJobStatus processor;
@Captor
private ArgumentCaptor<GetTranscriptionJobRequest> requestCaptor;
// Builds a TestRunner with dummy access-key credentials so no live AWS auth is required.
private TestRunner createRunner(final GetAwsTranscribeJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach
public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIAL_PROVIDER_NAME);
final GetAwsTranscribeJobStatus mockPollyFetcher = new GetAwsTranscribeJobStatus() {
processor = new GetAwsTranscribeJobStatus() {
@Override
protected AmazonTranscribeClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
public TranscribeClient getClient(final ProcessContext context) {
return mockTranscribeClient;
}
};
runner = TestRunners.newTestRunner(mockPollyFetcher);
runner.addControllerService(AWS_CREDENTIAL_PROVIDER_NAME, mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIAL_PROVIDER_NAME);
runner = createRunner(processor);
}
@Test
public void testTranscribeTaskInProgress() {
TranscriptionJob task = new TranscriptionJob()
.withTranscriptionJobName(TEST_TASK_ID)
.withTranscriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS);
GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
public void testTranscribeJobInProgress() {
final TranscriptionJob job = TranscriptionJob.builder()
.transcriptionJobName(TEST_TASK_ID)
.transcriptionJobStatus(TranscriptionJobStatus.IN_PROGRESS)
.build();
testTranscribeJob(job, REL_RUNNING);
}
@Test
public void testTranscribeTaskCompleted() {
TranscriptionJob task = new TranscriptionJob()
.withTranscriptionJobName(TEST_TASK_ID)
.withTranscript(new Transcript().withTranscriptFileUri(OUTPUT_LOCATION_PATH))
.withTranscriptionJobStatus(TranscriptionJobStatus.COMPLETED);
GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_SUCCESS);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
public void testTranscribeJobQueued() {
final TranscriptionJob job = TranscriptionJob.builder()
.transcriptionJobName(TEST_TASK_ID)
.transcriptionJobStatus(TranscriptionJobStatus.QUEUED)
.build();
testTranscribeJob(job, REL_RUNNING);
}
@Test
public void testTranscribeJobCompleted() {
// A COMPLETED job routes to success and exposes the transcript URI as an output-location attribute.
final TranscriptionJob job = TranscriptionJob.builder()
.transcriptionJobName(TEST_TASK_ID)
.transcript(Transcript.builder().transcriptFileUri(OUTPUT_LOCATION_PATH).build())
.transcriptionJobStatus(TranscriptionJobStatus.COMPLETED)
.build();
testTranscribeJob(job, REL_SUCCESS);
runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
}
@Test
public void testPollyTaskFailed() {
TranscriptionJob task = new TranscriptionJob()
.withTranscriptionJobName(TEST_TASK_ID)
.withFailureReason(REASON_OF_FAILURE)
.withTranscriptionJobStatus(TranscriptionJobStatus.FAILED);
GetTranscriptionJobResult taskResult = new GetTranscriptionJobResult().withTranscriptionJob(task);
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE);
public void testTranscribeJobFailed() {
final TranscriptionJob job = TranscriptionJob.builder()
.transcriptionJobName(TEST_TASK_ID)
.failureReason(REASON_OF_FAILURE)
.transcriptionJobStatus(TranscriptionJobStatus.FAILED)
.build();
testTranscribeJob(job, REL_FAILURE);
runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getTranscriptionJobName());
final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_FAILURE).iterator().next();
assertEquals(REASON_OF_FAILURE, flowFile.getAttribute(FAILURE_REASON_ATTRIBUTE));
}
@Test
public void testTranscribeJobUnrecognized() {
// A status value unknown to this SDK version is treated as a failure and keeps the failure reason.
final TranscriptionJob job = TranscriptionJob.builder()
.transcriptionJobName(TEST_TASK_ID)
.failureReason(REASON_OF_FAILURE)
.transcriptionJobStatus(TranscriptionJobStatus.UNKNOWN_TO_SDK_VERSION)
.build();
testTranscribeJob(job, REL_FAILURE);
runner.assertAllFlowFilesContainAttribute(FAILURE_REASON_ATTRIBUTE);
}
// Stubs the mocked client with the given job state, runs the processor once with the task id
// attribute set, and verifies both the routing and the job name sent in the request.
private void testTranscribeJob(final TranscriptionJob job, final Relationship expectedRelationship) {
final GetTranscriptionJobResponse response = GetTranscriptionJobResponse.builder().transcriptionJob(job).build();
when(mockTranscribeClient.getTranscriptionJob(requestCaptor.capture())).thenReturn(response);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().transcriptionJobName());
}
}

View File

@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.transcribe;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.transcribe.TranscribeClient;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobRequest;
import software.amazon.awssdk.services.transcribe.model.StartTranscriptionJobResponse;
import software.amazon.awssdk.services.transcribe.model.TranscriptionJob;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class StartAwsTranscribeJobTest {
private static final String TEST_TASK_ID = "testTaskId";
private TestRunner runner;
@Mock
private TranscribeClient mockTranscribeClient;
private StartAwsTranscribeJob processor;
private ObjectMapper objectMapper = JsonMapper.builder()
.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
.build();
@Captor
private ArgumentCaptor<StartTranscriptionJobRequest> requestCaptor;
private TestRunner createRunner(final StartAwsTranscribeJob processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach
public void setUp() throws InitializationException {
processor = new StartAwsTranscribeJob() {
@Override
public TranscribeClient getClient(ProcessContext context) {
return mockTranscribeClient;
}
};
runner = createRunner(processor);
}
@Test
public void testSuccessfulFlowfileContent() throws JsonProcessingException {
final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
.transcriptionJobName("Job")
.build();
final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder()
.transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build())
.build();
when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response);
final String requestJson = serialize(request);
runner.enqueue(requestJson);
runner.run();
runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartTranscriptionJobResponse parsedResponse = deserialize(responseData);
assertEquals("Job", requestCaptor.getValue().transcriptionJobName());
assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName());
}
@Test
public void testSuccessfulAttribute() throws JsonProcessingException {
final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
.transcriptionJobName("Job")
.build();
final StartTranscriptionJobResponse response = StartTranscriptionJobResponse.builder()
.transcriptionJob(TranscriptionJob.builder().transcriptionJobName(TEST_TASK_ID).build())
.build();
when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenReturn(response);
final String requestJson = serialize(request);
runner.setProperty(StartAwsTranscribeJob.JSON_PAYLOAD, "${json.payload}");
final Map<String, String> attributes = new HashMap<>();
attributes.put("json.payload", requestJson);
runner.enqueue("", attributes);
runner.run();
runner.assertTransferCount(REL_SUCCESS, 1);
runner.assertTransferCount(REL_ORIGINAL, 1);
final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
final StartTranscriptionJobResponse parsedResponse = deserialize(responseData);
assertEquals("Job", requestCaptor.getValue().transcriptionJobName());
assertEquals(TEST_TASK_ID, parsedResponse.transcriptionJob().transcriptionJobName());
}
@Test
public void testInvalidJson() {
final String requestJson = "invalid";
runner.enqueue(requestJson);
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
@Test
public void testServiceFailure() throws JsonProcessingException {
final StartTranscriptionJobRequest request = StartTranscriptionJobRequest.builder()
.transcriptionJobName("Job")
.build();
when(mockTranscribeClient.startTranscriptionJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
final String requestJson = serialize(request);
runner.enqueue(requestJson);
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
}
private StartTranscriptionJobResponse deserialize(final String responseData) throws JsonProcessingException {
return objectMapper.readValue(responseData, StartTranscriptionJobResponse.serializableBuilderClass()).build();
}
private String serialize(final StartTranscriptionJobRequest request) throws JsonProcessingException {
return objectMapper.writeValueAsString(request.toBuilder());
}
}

View File

@ -17,19 +17,11 @@
package org.apache.nifi.processors.aws.ml.translate;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.regions.Region;
import com.amazonaws.services.translate.AmazonTranslateClient;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobRequest;
import com.amazonaws.services.translate.model.DescribeTextTranslationJobResult;
import com.amazonaws.services.translate.model.JobStatus;
import com.amazonaws.services.translate.model.OutputDataConfig;
import com.amazonaws.services.translate.model.TextTranslationJobProperties;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderService;
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.MockFlowFile;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
@ -39,91 +31,132 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.DescribeTextTranslationJobResponse;
import software.amazon.awssdk.services.translate.model.JobStatus;
import software.amazon.awssdk.services.translate.model.OutputDataConfig;
import software.amazon.awssdk.services.translate.model.TextTranslationJobProperties;
import java.time.Instant;
import java.util.Collections;
import static org.apache.nifi.processors.aws.AbstractAWSCredentialsProviderProcessor.AWS_CREDENTIALS_PROVIDER_SERVICE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.AWS_TASK_OUTPUT_LOCATION;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_RUNNING;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.TASK_ID;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
public class GetAwsTranslateJobStatusTest {
private static final String TEST_TASK_ID = "testTaskId";
private static final String TEST_TASK_ID = "testJobId";
private static final String OUTPUT_LOCATION_PATH = "outputLocationPath";
private static final String REASON_OF_FAILURE = "reasonOfFailure";
private static final String CONTENT_STRING = "content";
private static final String AWS_CREDENTIALS_PROVIDER_NAME = "awsCredetialProvider";
private static final String OUTPUT_LOCATION_PATH = "outputLocation";
private TestRunner runner;
@Mock
private AmazonTranslateClient mockTranslateClient;
@Mock
private AWSCredentialsProviderService mockAwsCredentialsProvider;
private TranslateClient mockTranslateClient;
private GetAwsTranslateJobStatus processor;
@Captor
private ArgumentCaptor<DescribeTextTranslationJobRequest> requestCaptor;
// Builds a TestRunner with dummy access-key credentials so no live AWS auth is required.
private TestRunner createRunner(final GetAwsTranslateJobStatus processor) {
final TestRunner runner = TestRunners.newTestRunner(processor);
AuthUtils.enableAccessKey(runner, "abcd", "defg");
return runner;
}
@BeforeEach
public void setUp() throws InitializationException {
when(mockAwsCredentialsProvider.getIdentifier()).thenReturn(AWS_CREDENTIALS_PROVIDER_NAME);
final GetAwsTranslateJobStatus mockPollyFetcher = new GetAwsTranslateJobStatus() {
processor = new GetAwsTranslateJobStatus() {
@Override
protected AmazonTranslateClient createClient(final ProcessContext context, final AWSCredentialsProvider credentialsProvider, final Region region, final ClientConfiguration config,
final AwsClientBuilder.EndpointConfiguration endpointConfiguration) {
public TranslateClient getClient(final ProcessContext context) {
return mockTranslateClient;
}
};
runner = TestRunners.newTestRunner(mockPollyFetcher);
runner.addControllerService(AWS_CREDENTIALS_PROVIDER_NAME, mockAwsCredentialsProvider);
runner.enableControllerService(mockAwsCredentialsProvider);
runner.setProperty(AWS_CREDENTIALS_PROVIDER_SERVICE, AWS_CREDENTIALS_PROVIDER_NAME);
runner = createRunner(processor);
}
@Test
public void testTranscribeTaskInProgress() {
TextTranslationJobProperties task = new TextTranslationJobProperties()
.withJobId(TEST_TASK_ID)
.withJobStatus(JobStatus.IN_PROGRESS);
DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_RUNNING);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
public void testTranslateJobInProgress() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobId(TEST_TASK_ID)
.jobStatus(JobStatus.IN_PROGRESS)
.build();
testTranslateJob(job, REL_RUNNING);
}
@Test
public void testTranscribeTaskCompleted() {
TextTranslationJobProperties task = new TextTranslationJobProperties()
.withJobId(TEST_TASK_ID)
.withOutputDataConfig(new OutputDataConfig().withS3Uri(OUTPUT_LOCATION_PATH))
.withJobStatus(JobStatus.COMPLETED);
DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
public void testTranslateSubmitted() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobId(TEST_TASK_ID)
.jobStatus(JobStatus.SUBMITTED)
.build();
testTranslateJob(job, REL_RUNNING);
}
runner.assertAllFlowFilesTransferred(REL_SUCCESS);
@Test
public void testTranslateStopRequested() {
// STOP_REQUESTED is still an in-flight state, so the FlowFile remains on the running relationship.
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobId(TEST_TASK_ID)
.jobStatus(JobStatus.STOP_REQUESTED)
.build();
testTranslateJob(job, REL_RUNNING);
}
@Test
public void testTranslateJobCompleted() {
    // Fixed: the first builder call was .jobStatus(TEST_TASK_ID), which set the status to a
    // task-id string and was immediately overwritten by .jobStatus(JobStatus.COMPLETED),
    // leaving jobId unset. The intended call is .jobId(TEST_TASK_ID), matching
    // testTranslateStopRequested and the helper's jobId-based request assertion.
    final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
            .jobId(TEST_TASK_ID)
            .outputDataConfig(OutputDataConfig.builder().s3Uri(OUTPUT_LOCATION_PATH).build())
            .submittedTime(Instant.now())
            .jobStatus(JobStatus.COMPLETED)
            .build();
    // A COMPLETED job routes to success and exposes the S3 URI as the output-location attribute.
    testTranslateJob(job, REL_SUCCESS);
    runner.assertAllFlowFilesContainAttribute(AWS_TASK_OUTPUT_LOCATION);
    final MockFlowFile flowFile = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next();
    assertEquals(OUTPUT_LOCATION_PATH, flowFile.getAttribute(AWS_TASK_OUTPUT_LOCATION));
}
@Test
public void testTranscribeTaskFailed() {
TextTranslationJobProperties task = new TextTranslationJobProperties()
.withJobId(TEST_TASK_ID)
.withJobStatus(JobStatus.FAILED);
DescribeTextTranslationJobResult taskResult = new DescribeTextTranslationJobResult().withTextTranslationJobProperties(task);
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(taskResult);
public void testTranslateJobFailed() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobStatus(TEST_TASK_ID)
.jobStatus(JobStatus.FAILED)
.build();
testTranslateJob(job, REL_FAILURE);
}
@Test
public void testTranslateJobStopped() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobStatus(TEST_TASK_ID)
.jobStatus(JobStatus.STOPPED)
.build();
testTranslateJob(job, REL_FAILURE);
}
@Test
public void testTranslateJobUnrecognized() {
final TextTranslationJobProperties job = TextTranslationJobProperties.builder()
.jobStatus(TEST_TASK_ID)
.jobStatus(JobStatus.UNKNOWN_TO_SDK_VERSION)
.build();
testTranslateJob(job, REL_FAILURE);
}
private void testTranslateJob(final TextTranslationJobProperties job, final Relationship expectedRelationship) {
final DescribeTextTranslationJobResponse response = DescribeTextTranslationJobResponse.builder().textTranslationJobProperties(job).build();
when(mockTranslateClient.describeTextTranslationJob(requestCaptor.capture())).thenReturn(response);
runner.enqueue(CONTENT_STRING, Collections.singletonMap(TASK_ID.getName(), TEST_TASK_ID));
runner.run();
runner.assertAllFlowFilesTransferred(REL_FAILURE);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().getJobId());
runner.assertAllFlowFilesTransferred(expectedRelationship);
assertEquals(TEST_TASK_ID, requestCaptor.getValue().jobId());
}
}

View File

@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.aws.ml.translate;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import org.apache.nifi.processor.ProcessContext;
import org.apache.nifi.processors.aws.testutil.AuthUtils;
import org.apache.nifi.reporting.InitializationException;
import org.apache.nifi.util.TestRunner;
import org.apache.nifi.util.TestRunners;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.services.translate.TranslateClient;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobRequest;
import software.amazon.awssdk.services.translate.model.StartTextTranslationJobResponse;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_FAILURE;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_ORIGINAL;
import static org.apache.nifi.processors.aws.ml.AbstractAwsMachineLearningJobStatusProcessor.REL_SUCCESS;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
// Unit test for StartAwsTranslateJob: verifies that the processor reads a
// StartTextTranslationJobRequest as JSON (from FlowFile content or from an attribute via the
// JSON_PAYLOAD property), submits it through the Translate client, writes the response JSON to
// the success relationship, and routes bad input or service errors to failure.
@ExtendWith(MockitoExtension.class)
public class StartAwsTranslateJobTest {
    private static final String TEST_TASK_ID = "testTaskId";

    private TestRunner runner;

    // Mocked AWS SDK v2 client; injected through the anonymous processor subclass in setUp().
    @Mock
    private TranslateClient mockTranslateClient;

    private StartAwsTranslateJob processor;

    // Case-insensitive property matching so SDK model field names round-trip through JSON.
    private ObjectMapper objectMapper = JsonMapper.builder()
            .configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true)
            .build();

    // Captures the request handed to startTextTranslationJob so its fields can be asserted.
    @Captor
    private ArgumentCaptor<StartTextTranslationJobRequest> requestCaptor;

    // Builds a TestRunner for the processor with dummy static AWS credentials configured.
    private TestRunner createRunner(final StartAwsTranslateJob processor) {
        final TestRunner runner = TestRunners.newTestRunner(processor);
        AuthUtils.enableAccessKey(runner, "abcd", "defg");
        return runner;
    }

    @BeforeEach
    public void setUp() throws InitializationException {
        // Override getClient so the processor uses the mock instead of building a real client.
        processor = new StartAwsTranslateJob() {
            @Override
            public TranslateClient getClient(ProcessContext context) {
                return mockTranslateClient;
            }
        };
        runner = createRunner(processor);
    }

    // Request JSON supplied as FlowFile content: the parsed request reaches the client and
    // the response JSON (with the job id) is written to the success FlowFile.
    @Test
    public void testSuccessfulFlowfileContent() throws JsonProcessingException {
        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
                .terminologyNames("Name")
                .build();
        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.enqueue(requestJson);
        runner.run();

        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);
        assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    // Same scenario, but the request JSON is taken from a FlowFile attribute referenced by
    // the JSON_PAYLOAD expression instead of the FlowFile content.
    @Test
    public void testSuccessfulAttribute() throws JsonProcessingException {
        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
                .terminologyNames("Name")
                .build();
        final StartTextTranslationJobResponse response = StartTextTranslationJobResponse.builder()
                .jobId(TEST_TASK_ID)
                .build();
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenReturn(response);
        final String requestJson = serialize(request);
        runner.setProperty(StartAwsTranslateJob.JSON_PAYLOAD, "${json.payload}");
        final Map<String, String> attributes = new HashMap<>();
        attributes.put("json.payload", requestJson);
        runner.enqueue("", attributes);
        runner.run();

        runner.assertTransferCount(REL_SUCCESS, 1);
        runner.assertTransferCount(REL_ORIGINAL, 1);
        final String responseData = runner.getFlowFilesForRelationship(REL_SUCCESS).iterator().next().getContent();
        final StartTextTranslationJobResponse parsedResponse = deserialize(responseData);
        assertEquals(Collections.singletonList("Name"), requestCaptor.getValue().terminologyNames());
        assertEquals(TEST_TASK_ID, parsedResponse.jobId());
    }

    // Unparseable content must route to failure without calling the service.
    @Test
    public void testInvalidJson() {
        final String requestJson = "invalid";
        runner.enqueue(requestJson);
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    // A service-side exception from the client routes the FlowFile to failure.
    @Test
    public void testServiceFailure() throws JsonProcessingException {
        final StartTextTranslationJobRequest request = StartTextTranslationJobRequest.builder()
                .terminologyNames("Name")
                .build();
        when(mockTranslateClient.startTextTranslationJob(requestCaptor.capture())).thenThrow(AwsServiceException.builder().message("message").build());
        final String requestJson = serialize(request);
        runner.enqueue(requestJson);
        runner.run();

        runner.assertAllFlowFilesTransferred(REL_FAILURE, 1);
    }

    // SDK v2 models are not directly Jackson-friendly; round-trip through the
    // serializable builder class, mirroring the processor's own JSON handling.
    private StartTextTranslationJobResponse deserialize(final String responseData) throws JsonProcessingException {
        return objectMapper.readValue(responseData, StartTextTranslationJobResponse.serializableBuilderClass()).build();
    }

    private String serialize(final StartTextTranslationJobRequest request) throws JsonProcessingException {
        return objectMapper.writeValueAsString(request.toBuilder());
    }
}

View File

@ -66,7 +66,6 @@ public abstract class AbstractSQSIT {
queueUrl = response.queueUrl();
}
@AfterAll
public static void shutdown() {
client.close();