[ML] Parse and index inference model (#48016) (#48152)

This adds parsing of an inference model as a possible
result of the analytics process. When we do parse such a model,
we persist a `TrainedModelConfig` into the inference index
that contains additional metadata derived from the running job.
This commit is contained in:
Benjamin Trent 2019-10-16 15:46:20 -04:00 committed by GitHub
parent 74812f78dd
commit 0dddbb5b42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 420 additions and 228 deletions

View File

@ -30,21 +30,21 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class TrainedModelConfig implements ToXContentObject {
public static final String NAME = "trained_model_doc";
public static final String NAME = "trained_model_config";
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
public static final ParseField DESCRIPTION = new ParseField("description");
public static final ParseField CREATED_TIME = new ParseField("created_time");
public static final ParseField MODEL_VERSION = new ParseField("model_version");
public static final ParseField CREATE_TIME = new ParseField("create_time");
public static final ParseField DEFINITION = new ParseField("definition");
public static final ParseField MODEL_TYPE = new ParseField("model_type");
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
@ -55,16 +55,15 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
PARSER.declareField(TrainedModelConfig.Builder::setCreatedTime,
(p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()),
CREATED_TIME,
PARSER.declareField(TrainedModelConfig.Builder::setCreateTime,
(p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATE_TIME.getPreferredName()),
CREATE_TIME,
ObjectParser.ValueType.VALUE);
PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setDefinition,
(p, c) -> TrainedModelDefinition.fromXContent(p),
DEFINITION);
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
}
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@ -75,30 +74,27 @@ public class TrainedModelConfig implements ToXContentObject {
private final String createdBy;
private final Version version;
private final String description;
private final Instant createdTime;
private final Long modelVersion;
private final String modelType;
private final Map<String, Object> metadata;
private final Instant createTime;
private final TrainedModelDefinition definition;
private final List<String> tags;
private final Map<String, Object> metadata;
TrainedModelConfig(String modelId,
String createdBy,
Version version,
String description,
Instant createdTime,
Long modelVersion,
String modelType,
Instant createTime,
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
this.createdTime = Instant.ofEpochMilli(createdTime.toEpochMilli());
this.modelType = modelType;
this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli());
this.definition = definition;
this.description = description;
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.modelVersion = modelVersion;
}
public String getModelId() {
@ -117,16 +113,12 @@ public class TrainedModelConfig implements ToXContentObject {
return description;
}
public Instant getCreatedTime() {
return createdTime;
public Instant getCreateTime() {
return createTime;
}
public Long getModelVersion() {
return modelVersion;
}
public String getModelType() {
return modelType;
public List<String> getTags() {
return tags;
}
public Map<String, Object> getMetadata() {
@ -156,18 +148,15 @@ public class TrainedModelConfig implements ToXContentObject {
if (description != null) {
builder.field(DESCRIPTION.getPreferredName(), description);
}
if (createdTime != null) {
builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli());
}
if (modelVersion != null) {
builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
}
if (modelType != null) {
builder.field(MODEL_TYPE.getPreferredName(), modelType);
if (createTime != null) {
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
}
if (definition != null) {
builder.field(DEFINITION.getPreferredName(), definition);
}
if (tags != null) {
builder.field(TAGS.getPreferredName(), tags);
}
if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata);
}
@ -189,10 +178,9 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(createdBy, that.createdBy) &&
Objects.equals(version, that.version) &&
Objects.equals(description, that.description) &&
Objects.equals(createdTime, that.createdTime) &&
Objects.equals(modelVersion, that.modelVersion) &&
Objects.equals(modelType, that.modelType) &&
Objects.equals(createTime, that.createTime) &&
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(metadata, that.metadata);
}
@ -201,12 +189,11 @@ public class TrainedModelConfig implements ToXContentObject {
return Objects.hash(modelId,
createdBy,
version,
createdTime,
modelType,
createTime,
definition,
description,
metadata,
modelVersion);
tags,
metadata);
}
@ -216,11 +203,10 @@ public class TrainedModelConfig implements ToXContentObject {
private String createdBy;
private Version version;
private String description;
private Instant createdTime;
private Long modelVersion;
private String modelType;
private Instant createTime;
private Map<String, Object> metadata;
private TrainedModelDefinition.Builder definition;
private List<String> tags;
private TrainedModelDefinition definition;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -246,18 +232,13 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
private Builder setCreatedTime(Instant createdTime) {
this.createdTime = createdTime;
private Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
}
public Builder setModelVersion(Long modelVersion) {
this.modelVersion = modelVersion;
return this;
}
public Builder setModelType(String modelType) {
this.modelType = modelType;
public Builder setTags(List<String> tags) {
this.tags = tags;
return this;
}
@ -267,6 +248,11 @@ public class TrainedModelConfig implements ToXContentObject {
}
public Builder setDefinition(TrainedModelDefinition.Builder definition) {
this.definition = definition == null ? null : definition.build();
return this;
}
public Builder setDefinition(TrainedModelDefinition definition) {
this.definition = definition;
return this;
}
@ -277,10 +263,9 @@ public class TrainedModelConfig implements ToXContentObject {
createdBy,
version,
description,
createdTime,
modelVersion,
modelType,
definition == null ? null : definition.build(),
createTime,
definition,
tags,
metadata);
}
}

View File

@ -31,6 +31,8 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedModelConfig> {
@ -58,9 +60,9 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : randomNonNegativeLong(),
randomAlphaOfLength(10),
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
}

View File

@ -17,28 +17,30 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_doc";
public static final String NAME = "trained_model_config";
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
public static final ParseField DESCRIPTION = new ParseField("description");
public static final ParseField CREATED_TIME = new ParseField("created_time");
public static final ParseField MODEL_VERSION = new ParseField("model_version");
public static final ParseField CREATE_TIME = new ParseField("create_time");
public static final ParseField DEFINITION = new ParseField("definition");
public static final ParseField MODEL_TYPE = new ParseField("model_type");
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
@ -53,16 +55,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
parser.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
parser.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
parser.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
parser.declareField(TrainedModelConfig.Builder::setCreatedTime,
(p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()),
CREATED_TIME,
parser.declareField(TrainedModelConfig.Builder::setCreateTime,
(p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATE_TIME.getPreferredName()),
CREATE_TIME,
ObjectParser.ValueType.VALUE);
parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
parser.declareObject(TrainedModelConfig.Builder::setDefinition,
(p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
DEFINITION);
parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE);
return parser;
}
@ -70,41 +72,35 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
public static String documentId(String modelId, long modelVersion) {
return NAME + "-" + modelId + "-" + modelVersion;
}
private final String modelId;
private final String createdBy;
private final Version version;
private final String description;
private final Instant createdTime;
private final long modelVersion;
private final String modelType;
private final Instant createTime;
private final List<String> tags;
private final Map<String, Object> metadata;
// TODO how to reference and store large models that will not be executed in Java???
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
// TODO Should this be lazily parsed when loading via the index???
private final TrainedModelDefinition definition;
TrainedModelConfig(String modelId,
String createdBy,
Version version,
String description,
Instant createdTime,
Long modelVersion,
String modelType,
Instant createTime,
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
this.createdTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createdTime, CREATED_TIME).toEpochMilli());
this.modelType = ExceptionsHelper.requireNonNull(modelType, MODEL_TYPE);
this.createTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createTime, CREATE_TIME).toEpochMilli());
this.definition = definition;
this.description = description;
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.modelVersion = modelVersion == null ? 0 : modelVersion;
}
public TrainedModelConfig(StreamInput in) throws IOException {
@ -112,10 +108,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
createdBy = in.readString();
version = Version.readVersion(in);
description = in.readOptionalString();
createdTime = in.readInstant();
modelVersion = in.readVLong();
modelType = in.readString();
createTime = in.readInstant();
definition = in.readOptionalWriteable(TrainedModelDefinition::new);
tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
metadata = in.readMap();
}
@ -135,16 +130,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return description;
}
public Instant getCreatedTime() {
return createdTime;
public Instant getCreateTime() {
return createTime;
}
public long getModelVersion() {
return modelVersion;
}
public String getModelType() {
return modelType;
public List<String> getTags() {
return tags;
}
public Map<String, Object> getMetadata() {
@ -166,10 +157,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
out.writeString(createdBy);
Version.writeVersion(version, out);
out.writeOptionalString(description);
out.writeInstant(createdTime);
out.writeVLong(modelVersion);
out.writeString(modelType);
out.writeInstant(createTime);
out.writeOptionalWriteable(definition);
out.writeCollection(tags, StreamOutput::writeString);
out.writeMap(metadata);
}
@ -182,15 +172,17 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
if (description != null) {
builder.field(DESCRIPTION.getPreferredName(), description);
}
builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli());
builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
builder.field(MODEL_TYPE.getPreferredName(), modelType);
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
if (definition != null) {
builder.field(DEFINITION.getPreferredName(), definition);
}
builder.field(TAGS.getPreferredName(), tags);
if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata);
}
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
}
builder.endObject();
return builder;
}
@ -209,10 +201,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Objects.equals(createdBy, that.createdBy) &&
Objects.equals(version, that.version) &&
Objects.equals(description, that.description) &&
Objects.equals(createdTime, that.createdTime) &&
Objects.equals(modelVersion, that.modelVersion) &&
Objects.equals(modelType, that.modelType) &&
Objects.equals(createTime, that.createTime) &&
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(metadata, that.metadata);
}
@ -221,12 +212,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return Objects.hash(modelId,
createdBy,
version,
createdTime,
modelType,
createTime,
definition,
description,
metadata,
modelVersion);
tags,
metadata);
}
public static class Builder {
@ -235,11 +225,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private String createdBy;
private Version version;
private String description;
private Instant createdTime;
private Long modelVersion;
private String modelType;
private Instant createTime;
private List<String> tags = Collections.emptyList();
private Map<String, Object> metadata;
private TrainedModelDefinition.Builder definition;
private TrainedModelDefinition definition;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -265,18 +254,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
public Builder setCreatedTime(Instant createdTime) {
this.createdTime = createdTime;
public Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
}
public Builder setModelVersion(Long modelVersion) {
this.modelVersion = modelVersion;
return this;
}
public Builder setModelType(String modelType) {
this.modelType = modelType;
public Builder setTags(List<String> tags) {
this.tags = ExceptionsHelper.requireNonNull(tags, TAGS);
return this;
}
@ -286,6 +270,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
}
public Builder setDefinition(TrainedModelDefinition.Builder definition) {
this.definition = definition.build();
return this;
}
public Builder setDefinition(TrainedModelDefinition definition) {
this.definition = definition;
return this;
}
@ -316,9 +305,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
CREATED_BY.getPreferredName());
}
if (createdTime != null) {
if (createTime != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATED_TIME.getPreferredName());
CREATE_TIME.getPreferredName());
}
}
@ -328,23 +317,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
createdBy,
version,
description,
createdTime,
modelVersion,
modelType,
definition == null ? null : definition.build(),
metadata);
}
public TrainedModelConfig build(Version version) {
return new TrainedModelConfig(
modelId,
createdBy,
version,
description,
Instant.now(),
modelVersion,
modelType,
definition == null ? null : definition.build(),
createTime == null ? Instant.now() : createTime,
definition,
tags,
metadata);
}
}

View File

@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.inference.persistence;
import org.elasticsearch.common.ParseField;
/**
* Class containing the index constants so that the index version, name, and prefix are available to a wider audience.
*/
@ -14,6 +16,7 @@ public final class InferenceIndexConstants {
public static final String INDEX_NAME_PREFIX = ".ml-inference-";
public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;
public static final ParseField DOC_TYPE = new ParseField("doc_type");
private InferenceIndexConstants() {}

View File

@ -80,11 +80,11 @@ public final class Messages {
public static final String INVALID_GROUP = "Invalid group id ''{0}''; must be non-empty string and may contain lowercase alphanumeric" +
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] with version [{1}] already exists";
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL =
"Failed to serialize the trained model [{0}] with version [{1}] for storage";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}] with version [{1}]";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";

View File

@ -21,6 +21,7 @@ import org.junit.Before;
import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
@ -29,7 +30,6 @@ import java.util.stream.IntStream;
import static org.hamcrest.Matchers.equalTo;
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
private boolean lenient;
@ -56,15 +56,15 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
@Override
protected TrainedModelConfig createTestInstance() {
List<String> tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false));
return new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : randomNonNegativeLong(),
randomAlphaOfLength(10),
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
tags,
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
}
@ -116,9 +116,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreatedTime(Instant.now())
.setCreateTime(Instant.now())
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder()

View File

@ -200,6 +200,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationP
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
@ -523,7 +524,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
client,
clusterService);
normalizerProcessFactory = new NativeNormalizerProcessFactory(environment, nativeController, clusterService);
analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService);
analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService,
xContentRegistry);
memoryEstimationProcessFactory =
new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService);
} catch (IOException e) {
@ -566,9 +568,12 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
System::currentTimeMillis, anomalyDetectionAuditor, autodetectProcessManager);
this.datafeedManager.set(datafeedManager);
// Inference components
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
dataFrameAnalyticsAuditor);
dataFrameAnalyticsAuditor, trainedModelProvider);
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
new MemoryUsageEstimationProcessManager(
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);
@ -614,7 +619,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
analyticsProcessManager,
memoryEstimationProcessManager,
dataFrameAnalyticsConfigProvider,
nativeStorageProvider
nativeStorageProvider,
trainedModelProvider
);
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.ml.process.AbstractNativeProcess;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
@ -26,10 +27,11 @@ abstract class AbstractNativeAnalyticsProcess<Result> extends AbstractNativeProc
protected AbstractNativeAnalyticsProcess(String name, ConstructingObjectParser<Result, Void> resultParser, String jobId,
InputStream logStream, OutputStream processInStream,
InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields,
List<Path> filesToDelete, Consumer<String> onProcessCrash) {
List<Path> filesToDelete, Consumer<String> onProcessCrash,
NamedXContentRegistry namedXContentRegistry) {
super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash);
this.name = Objects.requireNonNull(name);
this.resultsParser = new ProcessResultsParser<>(Objects.requireNonNull(resultParser));
this.resultsParser = new ProcessResultsParser<>(Objects.requireNonNull(resultParser), namedXContentRegistry);
}
@Override

View File

@ -34,6 +34,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import java.io.IOException;
@ -57,15 +58,18 @@ public class AnalyticsProcessManager {
private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
private final DataFrameAnalyticsAuditor auditor;
private final TrainedModelProvider trainedModelProvider;
public AnalyticsProcessManager(Client client,
ThreadPool threadPool,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor) {
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
this.auditor = Objects.requireNonNull(auditor);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
}
public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
@ -356,7 +360,8 @@ public class AnalyticsProcessManager {
process = createProcess(task, config, analyticsProcessConfig, state);
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
dataExtractorFactory.newExtractor(true));
resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker());
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
trainedModelProvider, auditor);
return true;
}

View File

@ -8,33 +8,51 @@ package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import java.time.Instant;
import java.util.Collections;
import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
public class AnalyticsResultProcessor {
private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);
private final String dataFrameAnalyticsId;
private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final Supplier<Boolean> isProcessKilled;
private final ProgressTracker progressTracker;
private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor;
private final CountDownLatch completionLatch = new CountDownLatch(1);
private volatile String failure;
public AnalyticsResultProcessor(String dataFrameAnalyticsId, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier<Boolean> isProcessKilled,
ProgressTracker progressTracker) {
this.dataFrameAnalyticsId = Objects.requireNonNull(dataFrameAnalyticsId);
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor) {
this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
this.progressTracker = Objects.requireNonNull(progressTracker);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor);
}
@Nullable
@ -47,7 +65,7 @@ public class AnalyticsResultProcessor {
completionLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", dataFrameAnalyticsId), e);
LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", analytics.getId()), e);
}
}
@ -75,7 +93,7 @@ public class AnalyticsResultProcessor {
if (isProcessKilled.get()) {
// No need to log error as it's due to stopping
} else {
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", dataFrameAnalyticsId), e);
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
failure = "error parsing data frame analytics output: [" + e.getMessage() + "]";
}
} finally {
@ -93,5 +111,60 @@ public class AnalyticsResultProcessor {
if (progressPercent != null) {
progressTracker.analyzingPercent.set(progressPercent);
}
TrainedModelDefinition inferenceModel = result.getInferenceModel();
if (inferenceModel != null) {
createAndIndexInferenceModel(inferenceModel);
}
}
/**
 * Builds a {@link TrainedModelConfig} from the parsed inference model and stores it,
 * blocking (bounded at 30s) until the asynchronous store completes so that results
 * processing does not finish before the model is persisted.
 */
private void createAndIndexInferenceModel(TrainedModelDefinition inferenceModel) {
    CountDownLatch storeLatch = storeTrainedModel(createTrainedModelConfig(inferenceModel));
    try {
        boolean completedInTime = storeLatch.await(30, TimeUnit.SECONDS);
        if (completedInTime == false) {
            LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId());
        }
    } catch (InterruptedException e) {
        // Restore the interrupt status so callers further up the stack can observe it.
        Thread.currentThread().interrupt();
        LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for inference model to be stored", analytics.getId()), e);
    }
}
/**
 * Wraps the raw model definition in a {@link TrainedModelConfig} enriched with
 * job-derived metadata (creator, version, tags, description and the full analytics config).
 */
private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition inferenceModel) {
    Instant createTime = Instant.now();
    // The model id is made unique per run by suffixing the job id with the creation timestamp.
    String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
    return TrainedModelConfig.builder()
        .setModelId(modelId)
        .setCreatedBy("data-frame-analytics")
        .setVersion(Version.CURRENT)
        .setCreateTime(createTime)
        .setTags(Collections.singletonList(analytics.getId()))
        .setDescription(analytics.getDescription())
        // Store the originating analytics config as metadata so the model's provenance is queryable.
        .setMetadata(Collections.singletonMap("analytics_config",
            XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
        .setDefinition(inferenceModel)
        .build();
}
/**
 * Asynchronously stores the given trained model config.
 *
 * @return a latch that is counted down once the store has either succeeded or failed
 */
private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
    CountDownLatch latch = new CountDownLatch(1);
    String jobId = analytics.getId();
    String modelId = trainedModelConfig.getModelId();
    ActionListener<Boolean> storeListener = ActionListener.wrap(
        stored -> {
            if (stored) {
                LOGGER.info("[{}] Stored trained model with id [{}]", jobId, modelId);
                auditor.info(jobId, "Stored trained model with id [" + modelId + "]");
            } else {
                // The provider responded without an exception but reported the model was not stored.
                LOGGER.error("[{}] Storing trained model responded false", jobId);
            }
        },
        e -> {
            LOGGER.error(new ParameterizedMessage("[{}] Error storing trained model [{}]", jobId, modelId), e);
            auditor.error(jobId, "Error storing trained model with id [" + modelId
                + "]; error message [" + e.getMessage() + "]");
        }
    );
    // LatchedActionListener counts the latch down after either branch above has run.
    trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
    return latch;
}
}

View File

@ -6,15 +6,14 @@
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
@ -23,14 +22,14 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<Analy
private static final String NAME = "analytics";
private final ProcessResultsParser<AnalyticsResult> resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER);
private final AnalyticsProcessConfig config;
protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream,
OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
Consumer<String> onProcessCrash, AnalyticsProcessConfig config) {
Consumer<String> onProcessCrash, AnalyticsProcessConfig config,
NamedXContentRegistry namedXContentRegistry) {
super(NAME, AnalyticsResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields,
filesToDelete, onProcessCrash);
filesToDelete, onProcessCrash, namedXContentRegistry);
this.config = Objects.requireNonNull(config);
}
@ -49,11 +48,6 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<Analy
new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData();
}
@Override
public Iterator<AnalyticsResult> readAnalyticsResults() {
return resultsParser.parseResults(processOutStream());
}
@Override
public AnalyticsProcessConfig getConfig() {
return config;

View File

@ -13,6 +13,7 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.Environment;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@ -42,12 +43,15 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
private final Client client;
private final Environment env;
private final NativeController nativeController;
private final NamedXContentRegistry namedXContentRegistry;
private volatile Duration processConnectTimeout;
public NativeAnalyticsProcessFactory(Environment env, Client client, NativeController nativeController, ClusterService clusterService) {
public NativeAnalyticsProcessFactory(Environment env, Client client, NativeController nativeController, ClusterService clusterService,
NamedXContentRegistry namedXContentRegistry) {
this.env = Objects.requireNonNull(env);
this.client = Objects.requireNonNull(client);
this.nativeController = Objects.requireNonNull(nativeController);
this.namedXContentRegistry = Objects.requireNonNull(namedXContentRegistry);
setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT,
this::setProcessConnectTimeout);
@ -73,7 +77,8 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, processPipes.getLogStream().get(),
processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(),
processPipes.getRestoreStream().orElse(null), numberOfFields, filesToDelete, onProcessCrash, analyticsProcessConfig);
processPipes.getRestoreStream().orElse(null), numberOfFields, filesToDelete, onProcessCrash, analyticsProcessConfig,
namedXContentRegistry);
try {
startProcess(config, executorService, processPipes, analyticsProcess);

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import java.io.InputStream;
@ -23,7 +24,7 @@ public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsP
OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
Consumer<String> onProcessCrash) {
super(NAME, MemoryUsageEstimationResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream,
numberOfFields, filesToDelete, onProcessCrash);
numberOfFields, filesToDelete, onProcessCrash, NamedXContentRegistry.EMPTY);
}
@Override

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import java.io.IOException;
import java.util.Objects;
@ -20,21 +21,26 @@ public class AnalyticsResult implements ToXContentObject {
public static final ParseField TYPE = new ParseField("analytics_result");
public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
public static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1]));
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition) a[2]));
static {
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.STRICT_PARSER.apply(p, null).build(),
INFERENCE_MODEL);
}
private final RowResults rowResults;
private final Integer progressPercent;
private final TrainedModelDefinition inferenceModel;
public AnalyticsResult(RowResults rowResults, Integer progressPercent) {
/**
 * A single result object parsed from the analytics process output. Every field
 * is optional; a result typically carries only one of them.
 *
 * @param rowResults per-row results, or {@code null} if absent
 * @param progressPercent progress percentage, or {@code null} if absent
 * @param inferenceModel parsed inference model definition, or {@code null} if absent
 */
public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition inferenceModel) {
    this.rowResults = rowResults;
    this.progressPercent = progressPercent;
    this.inferenceModel = inferenceModel;
}
public RowResults getRowResults() {
@ -45,6 +51,10 @@ public class AnalyticsResult implements ToXContentObject {
return progressPercent;
}
// May be null when this result did not include an inference model.
public TrainedModelDefinition getInferenceModel() {
    return inferenceModel;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -54,6 +64,9 @@ public class AnalyticsResult implements ToXContentObject {
if (progressPercent != null) {
builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent);
}
if (inferenceModel != null) {
builder.field(INFERENCE_MODEL.getPreferredName(), inferenceModel);
}
builder.endObject();
return builder;
}
@ -68,11 +81,13 @@ public class AnalyticsResult implements ToXContentObject {
}
AnalyticsResult that = (AnalyticsResult) other;
return Objects.equals(rowResults, that.rowResults) && Objects.equals(progressPercent, that.progressPercent);
return Objects.equals(rowResults, that.rowResults)
&& Objects.equals(progressPercent, that.progressPercent)
&& Objects.equals(inferenceModel, that.inferenceModel);
}
@Override
public int hashCode() {
return Objects.hash(rowResults, progressPercent);
return Objects.hash(rowResults, progressPercent, inferenceModel);
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import java.io.IOException;
import java.util.Collections;
@ -23,13 +24,11 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappi
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.addMetaInformation;
/**
* Changelog of internal index versions
*
@ -68,6 +67,12 @@ public final class InferenceInternalIndex {
builder.field(DYNAMIC, "false");
builder.startObject(PROPERTIES);
// Add the doc_type field
builder.startObject(InferenceIndexConstants.DOC_TYPE.getPreferredName())
.field(TYPE, KEYWORD)
.endObject();
addInferenceDocFields(builder);
return builder.endObject()
.endObject()
@ -87,16 +92,13 @@ public final class InferenceInternalIndex {
.startObject(TrainedModelConfig.DESCRIPTION.getPreferredName())
.field(TYPE, TEXT)
.endObject()
.startObject(TrainedModelConfig.CREATED_TIME.getPreferredName())
.startObject(TrainedModelConfig.CREATE_TIME.getPreferredName())
.field(TYPE, DATE)
.endObject()
.startObject(TrainedModelConfig.MODEL_VERSION.getPreferredName())
.field(TYPE, LONG)
.endObject()
.startObject(TrainedModelConfig.DEFINITION.getPreferredName())
.field(ENABLED, false)
.endObject()
.startObject(TrainedModelConfig.MODEL_TYPE.getPreferredName())
.startObject(TrainedModelConfig.TAGS.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(TrainedModelConfig.METADATA.getPreferredName())

View File

@ -37,9 +37,11 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@ -57,26 +59,23 @@ public class TrainedModelProvider {
public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
XContentBuilder source = trainedModelConfig.toXContent(builder, ToXContent.EMPTY_PARAMS);
XContentBuilder source = trainedModelConfig.toXContent(builder,
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME)
.opType(DocWriteRequest.OpType.CREATE)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.id(TrainedModelConfig.documentId(trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion()))
.id(trainedModelConfig.getModelId())
.source(source);
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest,
ActionListener.wrap(
r -> listener.onResponse(true),
e -> {
logger.error(
new ParameterizedMessage("[{}][{}] failed to store trained model for inference",
trainedModelConfig.getModelId(),
trainedModelConfig.getModelVersion()),
e);
logger.error(new ParameterizedMessage(
"[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e);
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS,
trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion())));
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
} else {
listener.onFailure(
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
@ -93,10 +92,10 @@ public class TrainedModelProvider {
}
}
public void getTrainedModel(String modelId, long modelVersion, ActionListener<TrainedModelConfig> listener) {
public void getTrainedModel(String modelId, ActionListener<TrainedModelConfig> listener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
.idsQuery()
.addIds(TrainedModelConfig.documentId(modelId, modelVersion)));
.addIds(modelId));
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(queryBuilder)
// use sort to get the last
@ -109,11 +108,11 @@ public class TrainedModelProvider {
searchResponse -> {
if (searchResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion)));
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return;
}
BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef();
parseInferenceDocLenientlyFromSource(source, modelId, modelVersion, listener);
parseInferenceDocLenientlyFromSource(source, modelId, listener);
},
listener::onFailure));
}
@ -121,14 +120,13 @@ public class TrainedModelProvider {
private void parseInferenceDocLenientlyFromSource(BytesReference source,
String modelId,
long modelVersion,
ActionListener<TrainedModelConfig> modelListener) {
try (InputStream stream = source.streamInput();
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build());
} catch (Exception e) {
logger.error(new ParameterizedMessage("[{}][{}] failed to parse model", modelId, modelVersion), e);
logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e);
modelListener.onFailure(e);
}
}

View File

@ -12,14 +12,15 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.Environment;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
@ -78,7 +79,8 @@ public class NativeAutodetectProcessFactory implements AutodetectProcessFactory
int numberOfFields = job.allInputFields().size() + (includeTokensField ? 1 : 0) + 1;
IndexingStateProcessor stateProcessor = new IndexingStateProcessor(client, job.getId());
ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER);
ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER,
NamedXContentRegistry.EMPTY);
NativeAutodetectProcess autodetect = new NativeAutodetectProcess(
job.getId(), processPipes.getLogStream().get(), processPipes.getProcessInStream().get(),
processPipes.getProcessOutStream().get(), processPipes.getRestoreStream().orElse(null), numberOfFields,

View File

@ -32,15 +32,17 @@ public class ProcessResultsParser<T> {
private static final Logger logger = LogManager.getLogger(ProcessResultsParser.class);
private final ConstructingObjectParser<T, Void> resultParser;
private final NamedXContentRegistry namedXContentRegistry;
public ProcessResultsParser(ConstructingObjectParser<T, Void> resultParser) {
/**
 * @param resultParser parser for a single result object
 * @param namedXContentRegistry registry handed to the XContent parser so results may
 *                              contain named XContent objects (pass
 *                              {@code NamedXContentRegistry.EMPTY} when none are expected)
 */
public ProcessResultsParser(ConstructingObjectParser<T, Void> resultParser, NamedXContentRegistry namedXContentRegistry) {
    this.resultParser = Objects.requireNonNull(resultParser);
    this.namedXContentRegistry = Objects.requireNonNull(namedXContentRegistry);
}
public Iterator<T> parseResults(InputStream in) throws ElasticsearchParseException {
try {
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, in);
.createParser(namedXContentRegistry, LoggingDeprecationHandler.INSTANCE, in);
XContentParser.Token token = parser.nextToken();
// if start of an array ignore it, we expect an array of results
if (token != XContentParser.Token.START_ARRAY) {

View File

@ -5,21 +5,42 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mockito;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -28,16 +49,29 @@ import static org.mockito.Mockito.when;
public class AnalyticsResultProcessorTests extends ESTestCase {
private static final String JOB_ID = "analytics-result-processor-tests";
private static final String JOB_DESCRIPTION = "This describes the job of these tests";
private AnalyticsProcess<AnalyticsResult> process;
private DataFrameRowsJoiner dataFrameRowsJoiner;
private ProgressTracker progressTracker = new ProgressTracker();
private TrainedModelProvider trainedModelProvider;
private DataFrameAnalyticsAuditor auditor;
private DataFrameAnalyticsConfig analyticsConfig;
@Before
@SuppressWarnings("unchecked")
public void setUpMocks() {
    process = mock(AnalyticsProcess.class);
    dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class);
    trainedModelProvider = mock(TrainedModelProvider.class);
    auditor = mock(DataFrameAnalyticsAuditor.class);
    // Minimal valid analytics config; the id and description are asserted on by the
    // model-storing tests below.
    analyticsConfig = new DataFrameAnalyticsConfig.Builder()
        .setId(JOB_ID)
        .setDescription(JOB_DESCRIPTION)
        .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null))
        .setDest(new DataFrameAnalyticsDest("my_dest", null))
        .setAnalysis(new Regression("foo"))
        .build();
}
public void testProcess_GivenNoResults() {
@ -54,7 +88,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
public void testProcess_GivenEmptyResults() {
givenDataFrameRows(2);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50), new AnalyticsResult(null, 100)));
givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50, null), new AnalyticsResult(null, 100, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process);
@ -69,7 +103,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50), new AnalyticsResult(rowResults2, 100)));
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null), new AnalyticsResult(rowResults2, 100, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process);
@ -82,6 +116,71 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
}
@SuppressWarnings("unchecked")
public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
    givenDataFrameRows(0);

    // Make the provider report a successful store.
    doAnswer(invocationOnMock -> {
        ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
        storeListener.onResponse(true);
        return null;
    }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));

    TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
    // Single-element list: prefer Collections.singletonList over Arrays.asList.
    givenProcessResults(Collections.singletonList(new AnalyticsResult(null, null, inferenceModel)));
    AnalyticsResultProcessor resultProcessor = createResultProcessor();

    resultProcessor.process(process);
    resultProcessor.awaitForCompletion();

    // The stored config must carry the metadata derived from the running job.
    ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
    verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));

    TrainedModelConfig storedModel = storedModelCaptor.getValue();
    assertThat(storedModel.getModelId(), containsString(JOB_ID));
    assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
    assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));
    assertThat(storedModel.getTags(), contains(JOB_ID));
    assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
    assertThat(storedModel.getDefinition(), equalTo(inferenceModel));
    Map<String, Object> metadata = storedModel.getMetadata();
    assertThat(metadata.size(), equalTo(1));
    assertThat(metadata, hasKey("analytics_config"));
    Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
        true);
    assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));

    // A success audit message is emitted exactly once.
    ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
    verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
    assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
    // Use the already static-imported form for consistency with the rest of this file.
    verifyNoMoreInteractions(auditor);
}
@SuppressWarnings("unchecked")
public void testProcess_GivenInferenceModelFailedToStore() {
    givenDataFrameRows(0);

    // Make the provider report a store failure.
    doAnswer(invocationOnMock -> {
        ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
        storeListener.onFailure(new RuntimeException("some failure"));
        return null;
    }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));

    TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
    // Single-element list: prefer Collections.singletonList over Arrays.asList.
    givenProcessResults(Collections.singletonList(new AnalyticsResult(null, null, inferenceModel)));
    AnalyticsResultProcessor resultProcessor = createResultProcessor();

    resultProcessor.process(process);
    resultProcessor.awaitForCompletion();

    // This test verifies the processor knows how to handle a failure on storing the model and completes normally
    ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
    verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
    assertThat(auditCaptor.getValue(), containsString("Error storing trained model with id [" + JOB_ID));
    assertThat(auditCaptor.getValue(), containsString("[some failure]"));
    // Use the already static-imported form for consistency with the rest of this file.
    verifyNoMoreInteractions(auditor);
}
// Stubs the mocked process to stream back exactly the given results.
private void givenProcessResults(List<AnalyticsResult> results) {
    when(process.readAnalyticsResults()).thenReturn(results.iterator());
}
@ -93,6 +192,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
}
private AnalyticsResultProcessor createResultProcessor() {
return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, progressTracker);
return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider,
auditor);
}
}

View File

@ -5,24 +5,45 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResult> {
@Override
protected NamedXContentRegistry xContentRegistry() {
    // The inference model definition uses named XContent, so the ML inference entries
    // (plus the search module defaults) must be registered for round-trip parsing.
    List<NamedXContentRegistry.Entry> entries =
        new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    entries.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
    return new NamedXContentRegistry(entries);
}
@Override
protected AnalyticsResult createTestInstance() {
RowResults rowResults = null;
Integer progressPercent = null;
TrainedModelDefinition inferenceModel = null;
if (randomBoolean()) {
rowResults = RowResultsTests.createRandom();
}
if (randomBoolean()) {
progressPercent = randomIntBetween(0, 100);
}
return new AnalyticsResult(rowResults, progressPercent);
if (randomBoolean()) {
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
}
return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
}
@Override

View File

@ -39,7 +39,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
public void testPutTrainedModelConfig() throws Exception {
String modelId = "test-put-trained-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
TrainedModelConfig config = buildTrainedModelConfig(modelId);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@ -50,7 +50,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
public void testPutTrainedModelConfigThatAlreadyExists() throws Exception {
String modelId = "test-put-trained-model-config-exists";
TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
TrainedModelConfig config = buildTrainedModelConfig(modelId);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@ -61,12 +61,12 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId, 0)));
equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
}
public void testGetTrainedModelConfig() throws Exception {
String modelId = "test-get-trained-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
TrainedModelConfig config = buildTrainedModelConfig(modelId);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@@ -75,7 +75,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(config));
}
@@ -84,21 +84,20 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
String modelId = "test-get-missing-trained-model-config";
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, 0)));
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
}
private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
private static TrainedModelConfig buildTrainedModelConfig(String modelId) {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setDescription("trained model config for test")
.setModelId(modelId)
.setModelType("binary_decision_tree")
.setModelVersion(modelVersion)
.build(Version.CURRENT);
.setVersion(Version.CURRENT)
.build();
}
@Override

View File

@@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.ml.job.process.autodetect;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
@@ -61,7 +62,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
mock(OutputStream.class), outputStream, mock(OutputStream.class),
NUMBER_FIELDS, null,
new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
ZonedDateTime startTime = process.getProcessStartTime();
@@ -84,7 +85,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
ByteArrayOutputStream bos = new ByteArrayOutputStream(1024);
try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
process.writeRecord(record);
@@ -119,7 +120,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
ByteArrayOutputStream bos = new ByteArrayOutputStream(AutodetectControlMsgWriter.FLUSH_SPACES_LENGTH + 1024);
try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
FlushJobParams params = FlushJobParams.builder().build();
@@ -153,7 +154,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
processInStream, processOutStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
new ProcessResultsParser<AutodetectResult>(AutodetectResult.PARSER), mock(Consumer.class))) {
new ProcessResultsParser<AutodetectResult>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
process.consumeAndCloseOutputStream();
assertThat(processOutStream.available(), equalTo(0));
@@ -169,7 +170,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
ByteArrayOutputStream bos = new ByteArrayOutputStream(1024);
try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
writeFunction.accept(process);

View File

@@ -9,6 +9,7 @@ import com.google.common.base.Charsets;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.test.ESTestCase;
@@ -28,7 +29,7 @@ public class ProcessResultsParserTests extends ESTestCase {
public void testParse_GivenEmptyArray() throws IOException {
String json = "[]";
try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
assertFalse(parser.parseResults(inputStream).hasNext());
}
}
@@ -36,7 +37,7 @@ public class ProcessResultsParserTests extends ESTestCase {
public void testParse_GivenUnknownObject() throws IOException {
String json = "[{\"unknown\":{\"id\": 18}}]";
try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
XContentParseException e = expectThrows(XContentParseException.class,
() -> parser.parseResults(inputStream).forEachRemaining(a -> {
}));
@@ -47,7 +48,7 @@ public class ProcessResultsParserTests extends ESTestCase {
public void testParse_GivenArrayContainsAnotherArray() throws IOException {
String json = "[[]]";
try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class,
() -> parser.parseResults(inputStream).forEachRemaining(a -> {
}));
@@ -60,7 +61,7 @@ public class ProcessResultsParserTests extends ESTestCase {
+ " {\"field_1\": \"c\", \"field_2\": 3.0}]";
try (InputStream inputStream = new ByteArrayInputStream(input.getBytes(Charsets.UTF_8))) {
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
Iterator<TestResult> testResultIterator = parser.parseResults(inputStream);
List<TestResult> parsedResults = new ArrayList<>();