From 0dddbb5b426ca4edfe5b59644863ab98713fb411 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 16 Oct 2019 15:46:20 -0400 Subject: [PATCH] [ML] Parse and index inference model (#48016) (#48152) This adds parsing an inference model as a possible result of the analytics process. When we do parse such a model we persist a `TrainedModelConfig` into the inference index that contains additional metadata derived from the running job. --- .../ml/inference/TrainedModelConfig.java | 105 ++++++-------- .../ml/inference/TrainedModelConfigTests.java | 6 +- .../core/ml/inference/TrainedModelConfig.java | 133 +++++++----------- .../persistence/InferenceIndexConstants.java | 3 + .../xpack/core/ml/job/messages/Messages.java | 4 +- .../ml/inference/TrainedModelConfigTests.java | 10 +- .../xpack/ml/MachineLearning.java | 12 +- .../AbstractNativeAnalyticsProcess.java | 6 +- .../process/AnalyticsProcessManager.java | 9 +- .../process/AnalyticsResultProcessor.java | 85 ++++++++++- .../process/NativeAnalyticsProcess.java | 14 +- .../NativeAnalyticsProcessFactory.java | 9 +- .../NativeMemoryUsageEstimationProcess.java | 3 +- .../process/results/AnalyticsResult.java | 23 ++- .../persistence/InferenceInternalIndex.java | 16 ++- .../persistence/TrainedModelProvider.java | 28 ++-- .../NativeAutodetectProcessFactory.java | 6 +- .../ml/process/ProcessResultsParser.java | 6 +- .../AnalyticsResultProcessorTests.java | 106 +++++++++++++- .../process/results/AnalyticsResultTests.java | 23 ++- .../integration/TrainedModelProviderIT.java | 21 ++- .../NativeAutodetectProcessTests.java | 11 +- .../ml/process/ProcessResultsParserTests.java | 9 +- 23 files changed, 420 insertions(+), 228 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 792cc8f7303..f50c9b69eef 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -30,21 +30,21 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.time.Instant; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; public class TrainedModelConfig implements ToXContentObject { - public static final String NAME = "trained_model_doc"; + public static final String NAME = "trained_model_config"; public static final ParseField MODEL_ID = new ParseField("model_id"); public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField VERSION = new ParseField("version"); public static final ParseField DESCRIPTION = new ParseField("description"); - public static final ParseField CREATED_TIME = new ParseField("created_time"); - public static final ParseField MODEL_VERSION = new ParseField("model_version"); + public static final ParseField CREATE_TIME = new ParseField("create_time"); public static final ParseField DEFINITION = new ParseField("definition"); - public static final ParseField MODEL_TYPE = new ParseField("model_type"); + public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, @@ -55,16 +55,15 @@ public class TrainedModelConfig implements ToXContentObject { 
PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY); PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION); PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION); - PARSER.declareField(TrainedModelConfig.Builder::setCreatedTime, - (p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()), - CREATED_TIME, + PARSER.declareField(TrainedModelConfig.Builder::setCreateTime, + (p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATE_TIME.getPreferredName()), + CREATE_TIME, ObjectParser.ValueType.VALUE); - PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); - PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); - PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); PARSER.declareObject(TrainedModelConfig.Builder::setDefinition, (p, c) -> TrainedModelDefinition.fromXContent(p), DEFINITION); + PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); + PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); } public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { @@ -75,30 +74,27 @@ public class TrainedModelConfig implements ToXContentObject { private final String createdBy; private final Version version; private final String description; - private final Instant createdTime; - private final Long modelVersion; - private final String modelType; - private final Map metadata; + private final Instant createTime; private final TrainedModelDefinition definition; + private final List tags; + private final Map metadata; TrainedModelConfig(String modelId, String createdBy, Version version, String description, - Instant createdTime, - Long modelVersion, - String modelType, + Instant createTime, TrainedModelDefinition definition, + List tags, Map metadata) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; - this.createdTime = Instant.ofEpochMilli(createdTime.toEpochMilli()); - this.modelType = modelType; + this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli()); this.definition = definition; this.description = description; + this.tags = tags == null ? null : Collections.unmodifiableList(tags); this.metadata = metadata == null ? 
null : Collections.unmodifiableMap(metadata); - this.modelVersion = modelVersion; } public String getModelId() { @@ -117,16 +113,12 @@ public class TrainedModelConfig implements ToXContentObject { return description; } - public Instant getCreatedTime() { - return createdTime; + public Instant getCreateTime() { + return createTime; } - public Long getModelVersion() { - return modelVersion; - } - - public String getModelType() { - return modelType; + public List getTags() { + return tags; } public Map getMetadata() { @@ -156,18 +148,15 @@ public class TrainedModelConfig implements ToXContentObject { if (description != null) { builder.field(DESCRIPTION.getPreferredName(), description); } - if (createdTime != null) { - builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli()); - } - if (modelVersion != null) { - builder.field(MODEL_VERSION.getPreferredName(), modelVersion); - } - if (modelType != null) { - builder.field(MODEL_TYPE.getPreferredName(), modelType); + if (createTime != null) { + builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); } if (definition != null) { builder.field(DEFINITION.getPreferredName(), definition); } + if (tags != null) { + builder.field(TAGS.getPreferredName(), tags); + } if (metadata != null) { builder.field(METADATA.getPreferredName(), metadata); } @@ -189,10 +178,9 @@ public class TrainedModelConfig implements ToXContentObject { Objects.equals(createdBy, that.createdBy) && Objects.equals(version, that.version) && Objects.equals(description, that.description) && - Objects.equals(createdTime, that.createdTime) && - Objects.equals(modelVersion, that.modelVersion) && - Objects.equals(modelType, that.modelType) && + Objects.equals(createTime, that.createTime) && Objects.equals(definition, that.definition) && + Objects.equals(tags, that.tags) && Objects.equals(metadata, that.metadata); } @@ -201,12 +189,11 @@ public class TrainedModelConfig implements ToXContentObject { return Objects.hash(modelId, createdBy, version, - createdTime, - modelType, + createTime, definition, description, - metadata, - modelVersion); + tags, + metadata); } @@ -216,11 +203,10 @@ public class TrainedModelConfig implements ToXContentObject { private String createdBy; private Version version; private String description; - private Instant createdTime; - private Long modelVersion; - private String modelType; + private Instant createTime; private Map metadata; - private TrainedModelDefinition.Builder definition; + private List tags; + private TrainedModelDefinition definition; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -246,18 +232,13 @@ public class TrainedModelConfig implements ToXContentObject { return this; } - private Builder setCreatedTime(Instant createdTime) { - this.createdTime = createdTime; + private Builder setCreateTime(Instant createTime) { + this.createTime = createTime; return this; } - public Builder setModelVersion(Long modelVersion) { - this.modelVersion = modelVersion; - return this; - } - - public Builder setModelType(String modelType) { - this.modelType = modelType; + public Builder setTags(List tags) { + this.tags = tags; return this; } @@ -267,6 +248,11 @@ public class TrainedModelConfig implements ToXContentObject { } public Builder setDefinition(TrainedModelDefinition.Builder definition) { + this.definition = definition == null ? 
null : definition.build(); + return this; + } + + public Builder setDefinition(TrainedModelDefinition definition) { this.definition = definition; return this; } @@ -277,10 +263,9 @@ public class TrainedModelConfig implements ToXContentObject { createdBy, version, description, - createdTime, - modelVersion, - modelType, - definition == null ? null : definition.build(), + createTime, + definition, + tags, metadata); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index e28ca416e24..7825a1bd955 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -31,6 +31,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class TrainedModelConfigTests extends AbstractXContentTestCase { @@ -58,9 +60,9 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 6140c438783..e1c24eee02b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -17,28 +17,30 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.common.time.TimeUtils; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import java.time.Instant; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; public class TrainedModelConfig implements ToXContentObject, Writeable { - public static final String NAME = "trained_model_doc"; + public static final String NAME = "trained_model_config"; public static final ParseField MODEL_ID = new ParseField("model_id"); public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField VERSION = new ParseField("version"); public static final ParseField DESCRIPTION = new ParseField("description"); - public static final ParseField CREATED_TIME = new ParseField("created_time"); - public static final ParseField MODEL_VERSION = new ParseField("model_version"); + public static final ParseField CREATE_TIME = new ParseField("create_time"); public static final ParseField DEFINITION = new ParseField("definition"); - public static final ParseField MODEL_TYPE = new ParseField("model_type"); + public static final 
ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly @@ -53,16 +55,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { parser.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY); parser.declareString(TrainedModelConfig.Builder::setVersion, VERSION); parser.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION); - parser.declareField(TrainedModelConfig.Builder::setCreatedTime, - (p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()), - CREATED_TIME, + parser.declareField(TrainedModelConfig.Builder::setCreateTime, + (p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATE_TIME.getPreferredName()), + CREATE_TIME, ObjectParser.ValueType.VALUE); - parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); - parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); + parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); parser.declareObject(TrainedModelConfig.Builder::setDefinition, (p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields), DEFINITION); + parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE); return parser; } @@ -70,41 +72,35 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); } - public static String documentId(String modelId, long modelVersion) { - return NAME + "-" + modelId + "-" + modelVersion; - } - - private final String modelId; private final String createdBy; private final Version version; private final String description; - private final Instant createdTime; - private final long modelVersion; - private final String modelType; + private final Instant createTime; + private final List tags; private final Map metadata; + // TODO how to reference and store large models that will not be executed in Java??? // Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something // TODO Should this be lazily parsed when loading via the index??? private final TrainedModelDefinition definition; + TrainedModelConfig(String modelId, String createdBy, Version version, String description, - Instant createdTime, - Long modelVersion, - String modelType, + Instant createTime, TrainedModelDefinition definition, + List tags, Map metadata) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); - this.createdTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createdTime, CREATED_TIME).toEpochMilli()); - this.modelType = ExceptionsHelper.requireNonNull(modelType, MODEL_TYPE); + this.createTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createTime, CREATE_TIME).toEpochMilli()); this.definition = definition; this.description = description; + this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS)); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); - this.modelVersion = modelVersion == null ? 
0 : modelVersion; } public TrainedModelConfig(StreamInput in) throws IOException { @@ -112,10 +108,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { createdBy = in.readString(); version = Version.readVersion(in); description = in.readOptionalString(); - createdTime = in.readInstant(); - modelVersion = in.readVLong(); - modelType = in.readString(); + createTime = in.readInstant(); definition = in.readOptionalWriteable(TrainedModelDefinition::new); + tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); metadata = in.readMap(); } @@ -135,16 +130,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return description; } - public Instant getCreatedTime() { - return createdTime; + public Instant getCreateTime() { + return createTime; } - public long getModelVersion() { - return modelVersion; - } - - public String getModelType() { - return modelType; + public List getTags() { + return tags; } public Map getMetadata() { @@ -166,10 +157,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { out.writeString(createdBy); Version.writeVersion(version, out); out.writeOptionalString(description); - out.writeInstant(createdTime); - out.writeVLong(modelVersion); - out.writeString(modelType); + out.writeInstant(createTime); out.writeOptionalWriteable(definition); + out.writeCollection(tags, StreamOutput::writeString); out.writeMap(metadata); } @@ -182,15 +172,17 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { if (description != null) { builder.field(DESCRIPTION.getPreferredName(), description); } - builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli()); - builder.field(MODEL_VERSION.getPreferredName(), modelVersion); - builder.field(MODEL_TYPE.getPreferredName(), modelType); + builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); if (definition != null) { builder.field(DEFINITION.getPreferredName(), definition); } + builder.field(TAGS.getPreferredName(), tags); if (metadata != null) { builder.field(METADATA.getPreferredName(), metadata); } + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); + } builder.endObject(); return builder; } @@ -209,10 +201,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { Objects.equals(createdBy, that.createdBy) && Objects.equals(version, that.version) && Objects.equals(description, that.description) && - Objects.equals(createdTime, that.createdTime) && - Objects.equals(modelVersion, that.modelVersion) && - Objects.equals(modelType, that.modelType) && + Objects.equals(createTime, that.createTime) && Objects.equals(definition, that.definition) && + Objects.equals(tags, that.tags) && Objects.equals(metadata, that.metadata); } @@ -221,12 +212,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return Objects.hash(modelId, createdBy, version, - createdTime, - modelType, + createTime, definition, description, - metadata, - modelVersion); + tags, + metadata); } public static class Builder { @@ -235,11 +225,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { private String createdBy; private Version version; private String description; - private Instant createdTime; - private Long modelVersion; - private String modelType; + private Instant 
createTime; + private List tags = Collections.emptyList(); private Map metadata; - private TrainedModelDefinition.Builder definition; + private TrainedModelDefinition definition; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -265,18 +254,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return this; } - public Builder setCreatedTime(Instant createdTime) { - this.createdTime = createdTime; + public Builder setCreateTime(Instant createTime) { + this.createTime = createTime; return this; } - public Builder setModelVersion(Long modelVersion) { - this.modelVersion = modelVersion; - return this; - } - - public Builder setModelType(String modelType) { - this.modelType = modelType; + public Builder setTags(List tags) { + this.tags = ExceptionsHelper.requireNonNull(tags, TAGS); return this; } @@ -286,6 +270,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { } public Builder setDefinition(TrainedModelDefinition.Builder definition) { + this.definition = definition.build(); + return this; + } + + public Builder setDefinition(TrainedModelDefinition definition) { this.definition = definition; return this; } @@ -316,9 +305,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { CREATED_BY.getPreferredName()); } - if (createdTime != null) { + if (createTime != null) { throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", - CREATED_TIME.getPreferredName()); + CREATE_TIME.getPreferredName()); } } @@ -328,23 +317,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { createdBy, version, description, - createdTime, - modelVersion, - modelType, - definition == null ? null : definition.build(), - metadata); - } - - public TrainedModelConfig build(Version version) { - return new TrainedModelConfig( - modelId, - createdBy, - version, - description, - Instant.now(), - modelVersion, - modelType, - definition == null ? null : definition.build(), + createTime == null ? Instant.now() : createTime, + definition, + tags, metadata); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index e5820f4068e..8621786cedd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.core.ml.inference.persistence; +import org.elasticsearch.common.ParseField; + /** * Class containing the index constants so that the index version, name, and prefix are available to a wider audience. 
*/ @@ -14,6 +16,7 @@ public final class InferenceIndexConstants { public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; + public static final ParseField DOC_TYPE = new ParseField("doc_type"); private InferenceIndexConstants() {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index c302d04186a..75cc468160d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -80,11 +80,11 @@ public final class Messages { public static final String INVALID_GROUP = "Invalid group id ''{0}''; must be non-empty string and may contain lowercase alphanumeric" + " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric"; - public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] with version [{1}] already exists"; + public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL = "Failed to serialize the trained model [{0}] with version [{1}] for storage"; - public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}] with version [{1}]"; + public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 87c981ebd4a..3117590307f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -21,6 +21,7 @@ import org.junit.Before; import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.Predicate; @@ -29,7 +30,6 @@ import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; - public class TrainedModelConfigTests extends AbstractSerializingTestCase { private boolean lenient; @@ -56,15 +56,15 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false)); return new TrainedModelConfig( randomAlphaOfLength(10), randomAlphaOfLength(10), Version.CURRENT, randomBoolean() ? null : randomAlphaOfLength(100), Instant.ofEpochMilli(randomNonNegativeLong()), - randomBoolean() ? null : randomNonNegativeLong(), - randomAlphaOfLength(10), randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), + tags, randomBoolean() ? 
null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); } @@ -116,9 +116,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase TrainedModelConfig.builder() .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) - .setCreatedTime(Instant.now()) + .setCreateTime(Instant.now()) .setModelId(modelId).validate()); - assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation")); + assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation")); ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index d3ca22a62a5..ab2c7a00f08 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -200,6 +200,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationP import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -523,7 +524,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu client, clusterService); normalizerProcessFactory = new NativeNormalizerProcessFactory(environment, nativeController, clusterService); - analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService); + analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService, + xContentRegistry); memoryEstimationProcessFactory = new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService); } catch (IOException e) { @@ -566,9 +568,12 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu System::currentTimeMillis, anomalyDetectionAuditor, autodetectProcessManager); this.datafeedManager.set(datafeedManager); + // Inference components + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, - dataFrameAnalyticsAuditor); + dataFrameAnalyticsAuditor, trainedModelProvider); MemoryUsageEstimationProcessManager memoryEstimationProcessManager = new MemoryUsageEstimationProcessManager( threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory); @@ -614,7 +619,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu analyticsProcessManager, memoryEstimationProcessManager, dataFrameAnalyticsConfigProvider, - nativeStorageProvider + nativeStorageProvider, + trainedModelProvider ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AbstractNativeAnalyticsProcess.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AbstractNativeAnalyticsProcess.java index 55481de160b..6ceaa4ce9cd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AbstractNativeAnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AbstractNativeAnalyticsProcess.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; import org.elasticsearch.xpack.ml.process.ProcessResultsParser; @@ -26,10 +27,11 @@ abstract class AbstractNativeAnalyticsProcess extends AbstractNativeProc protected AbstractNativeAnalyticsProcess(String name, ConstructingObjectParser resultParser, String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields, - List filesToDelete, Consumer onProcessCrash) { + List filesToDelete, Consumer onProcessCrash, + NamedXContentRegistry namedXContentRegistry) { super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash); this.name = Objects.requireNonNull(name); - this.resultsParser = new ProcessResultsParser<>(Objects.requireNonNull(resultParser)); + this.resultsParser = new ProcessResultsParser<>(Objects.requireNonNull(resultParser), namedXContentRegistry); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index b1022ecc7fe..485b9d9d605 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor; import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import java.io.IOException; @@ -57,15 +58,18 @@ public class AnalyticsProcessManager { private final AnalyticsProcessFactory processFactory; private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); private final DataFrameAnalyticsAuditor auditor; + private final TrainedModelProvider trainedModelProvider; public AnalyticsProcessManager(Client client, ThreadPool threadPool, AnalyticsProcessFactory analyticsProcessFactory, - DataFrameAnalyticsAuditor auditor) { + DataFrameAnalyticsAuditor auditor, + TrainedModelProvider trainedModelProvider) { this.client = Objects.requireNonNull(client); this.threadPool = Objects.requireNonNull(threadPool); this.processFactory = Objects.requireNonNull(analyticsProcessFactory); this.auditor = Objects.requireNonNull(auditor); + this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); } public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, 
DataFrameDataExtractorFactory dataExtractorFactory, @@ -356,7 +360,8 @@ public class AnalyticsProcessManager { process = createProcess(task, config, analyticsProcessConfig, state); DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); - resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker()); + resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(), + trainedModelProvider, auditor); return true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 079feecbab6..72337f77e9f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -8,33 +8,51 @@ package org.elasticsearch.xpack.ml.dataframe.process; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import java.time.Instant; +import java.util.Collections; import java.util.Iterator; import java.util.Objects; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); - private final String dataFrameAnalyticsId; + private final DataFrameAnalyticsConfig analytics; private final DataFrameRowsJoiner dataFrameRowsJoiner; private final Supplier isProcessKilled; private final ProgressTracker progressTracker; + private final TrainedModelProvider trainedModelProvider; + private final DataFrameAnalyticsAuditor auditor; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; - public AnalyticsResultProcessor(String dataFrameAnalyticsId, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier isProcessKilled, - ProgressTracker progressTracker) { - this.dataFrameAnalyticsId = Objects.requireNonNull(dataFrameAnalyticsId); + public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, + Supplier isProcessKilled, ProgressTracker progressTracker, + TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor) { + 
this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.isProcessKilled = Objects.requireNonNull(isProcessKilled); this.progressTracker = Objects.requireNonNull(progressTracker); + this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); + this.auditor = Objects.requireNonNull(auditor); } @Nullable @@ -47,7 +65,7 @@ public class AnalyticsResultProcessor { completionLatch.await(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); - LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", dataFrameAnalyticsId), e); + LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", analytics.getId()), e); } } @@ -75,7 +93,7 @@ public class AnalyticsResultProcessor { if (isProcessKilled.get()) { // No need to log error as it's due to stopping } else { - LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", dataFrameAnalyticsId), e); + LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e); failure = "error parsing data frame analytics output: [" + e.getMessage() + "]"; } } finally { @@ -93,5 +111,60 @@ public class AnalyticsResultProcessor { if (progressPercent != null) { progressTracker.analyzingPercent.set(progressPercent); } + TrainedModelDefinition inferenceModel = result.getInferenceModel(); + if (inferenceModel != null) { + createAndIndexInferenceModel(inferenceModel); + } + } + + private void createAndIndexInferenceModel(TrainedModelDefinition inferenceModel) { + TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel); + CountDownLatch latch = storeTrainedModel(trainedModelConfig); + + try { + if (latch.await(30, TimeUnit.SECONDS) == false) { + LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId()); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for inference model to be stored", analytics.getId()), e); + } + } + + private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition inferenceModel) { + Instant createTime = Instant.now(); + return TrainedModelConfig.builder() + .setModelId(analytics.getId() + "-" + createTime.toEpochMilli()) + .setCreatedBy("data-frame-analytics") + .setVersion(Version.CURRENT) + .setCreateTime(createTime) + .setTags(Collections.singletonList(analytics.getId())) + .setDescription(analytics.getDescription()) + .setMetadata(Collections.singletonMap("analytics_config", + XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) + .setDefinition(inferenceModel) + .build(); + } + + private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) { + CountDownLatch latch = new CountDownLatch(1); + ActionListener storeListener = ActionListener.wrap( + aBoolean -> { + if (aBoolean == false) { + LOGGER.error("[{}] Storing trained model responded false", analytics.getId()); + } else { + LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); + auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]"); + } + }, + e -> { + LOGGER.error(new ParameterizedMessage("[{}] Error storing trained model [{}]", analytics.getId(), + trainedModelConfig.getModelId()), e); + 
auditor.error(analytics.getId(), "Error storing trained model with id [" + trainedModelConfig.getModelId() + + "]; error message [" + e.getMessage() + "]"); + } + ); + trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); + return latch; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java index 53578859582..e606f533ce2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java @@ -6,15 +6,14 @@ package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; -import org.elasticsearch.xpack.ml.process.ProcessResultsParser; import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.file.Path; -import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.function.Consumer; @@ -23,14 +22,14 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER); private final AnalyticsProcessConfig config; protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields, List filesToDelete, - Consumer onProcessCrash, AnalyticsProcessConfig config) { + Consumer onProcessCrash, AnalyticsProcessConfig config, + NamedXContentRegistry namedXContentRegistry) { super(NAME, AnalyticsResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, - filesToDelete, onProcessCrash); + filesToDelete, onProcessCrash, namedXContentRegistry); this.config = Objects.requireNonNull(config); } @@ -49,11 +48,6 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess readAnalyticsResults() { - return resultsParser.parseResults(processOutStream()); - } - @Override public AnalyticsProcessConfig getConfig() { return config; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java index 39e62869b44..78c45537034 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.env.Environment; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -42,12 +43,15 @@ public class NativeAnalyticsProcessFactory 
implements AnalyticsProcessFactory filesToDelete, Consumer onProcessCrash) { super(NAME, MemoryUsageEstimationResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, - numberOfFields, filesToDelete, onProcessCrash); + numberOfFields, filesToDelete, onProcessCrash, NamedXContentRegistry.EMPTY); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 8118c3645f1..c383fd19576 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import java.io.IOException; import java.util.Objects; @@ -20,21 +21,26 @@ public class AnalyticsResult implements ToXContentObject { public static final ParseField TYPE = new ParseField("analytics_result"); public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + public static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), - a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1])); + a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition) a[2])); static { PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.STRICT_PARSER.apply(p, null).build(), + INFERENCE_MODEL); } private final RowResults rowResults; private final Integer progressPercent; + private final TrainedModelDefinition inferenceModel; - public AnalyticsResult(RowResults rowResults, Integer progressPercent) { + public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition inferenceModel) { this.rowResults = rowResults; this.progressPercent = progressPercent; + this.inferenceModel = inferenceModel; } public RowResults getRowResults() { @@ -45,6 +51,10 @@ public class AnalyticsResult implements ToXContentObject { return progressPercent; } + public TrainedModelDefinition getInferenceModel() { + return inferenceModel; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -54,6 +64,9 @@ public class AnalyticsResult implements ToXContentObject { if (progressPercent != null) { builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent); } + if (inferenceModel != null) { + builder.field(INFERENCE_MODEL.getPreferredName(), inferenceModel); + } builder.endObject(); return builder; } @@ -68,11 +81,13 @@ public class AnalyticsResult implements ToXContentObject { } AnalyticsResult that = (AnalyticsResult) other; - return Objects.equals(rowResults, that.rowResults) && Objects.equals(progressPercent, that.progressPercent); + return Objects.equals(rowResults, that.rowResults) + && Objects.equals(progressPercent, that.progressPercent) + && 
Objects.equals(inferenceModel, that.inferenceModel); } @Override public int hashCode() { - return Objects.hash(rowResults, progressPercent); + return Objects.hash(rowResults, progressPercent, inferenceModel); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java index 2f1cf2aed4e..33a4180b25f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import java.io.IOException; import java.util.Collections; @@ -23,13 +24,11 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappi import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD; -import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.addMetaInformation; - /** * Changelog of internal index versions * @@ -68,6 +67,12 @@ public final class InferenceInternalIndex { builder.field(DYNAMIC, "false"); builder.startObject(PROPERTIES); + + // Add the doc_type field + builder.startObject(InferenceIndexConstants.DOC_TYPE.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject(); + addInferenceDocFields(builder); return builder.endObject() .endObject() @@ -87,16 +92,13 @@ public final class InferenceInternalIndex { .startObject(TrainedModelConfig.DESCRIPTION.getPreferredName()) .field(TYPE, TEXT) .endObject() - .startObject(TrainedModelConfig.CREATED_TIME.getPreferredName()) + .startObject(TrainedModelConfig.CREATE_TIME.getPreferredName()) .field(TYPE, DATE) .endObject() - .startObject(TrainedModelConfig.MODEL_VERSION.getPreferredName()) - .field(TYPE, LONG) - .endObject() .startObject(TrainedModelConfig.DEFINITION.getPreferredName()) .field(ENABLED, false) .endObject() - .startObject(TrainedModelConfig.MODEL_TYPE.getPreferredName()) + .startObject(TrainedModelConfig.TAGS.getPreferredName()) .field(TYPE, KEYWORD) .endObject() .startObject(TrainedModelConfig.METADATA.getPreferredName()) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 2028dfe9edf..6f1e543896c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -37,9 +37,11 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import java.io.InputStream; +import java.util.Collections; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -57,26 +59,23 @@ public class TrainedModelProvider { public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener listener) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - XContentBuilder source = trainedModelConfig.toXContent(builder, ToXContent.EMPTY_PARAMS); + XContentBuilder source = trainedModelConfig.toXContent(builder, + new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME) .opType(DocWriteRequest.OpType.CREATE) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .id(TrainedModelConfig.documentId(trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion())) + .id(trainedModelConfig.getModelId()) .source(source); executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap( r -> listener.onResponse(true), e -> { - logger.error( - new ParameterizedMessage("[{}][{}] failed to store trained model for inference", - trainedModelConfig.getModelId(), - trainedModelConfig.getModelVersion()), - e); + logger.error(new ParameterizedMessage( + "[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e); if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { listener.onFailure(new ResourceAlreadyExistsException( - Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, - trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion()))); + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); } else { listener.onFailure( new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, @@ -93,10 +92,10 @@ public class TrainedModelProvider { } } - public void getTrainedModel(String modelId, long modelVersion, ActionListener listener) { + public void getTrainedModel(String modelId, ActionListener listener) { QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders .idsQuery() - .addIds(TrainedModelConfig.documentId(modelId, modelVersion))); + .addIds(modelId)); SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) .setQuery(queryBuilder) // use sort to get the last @@ -109,11 +108,11 @@ public class TrainedModelProvider { searchResponse -> { if (searchResponse.getHits().getHits().length == 0) { listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion))); + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return; } BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef(); - parseInferenceDocLenientlyFromSource(source, modelId, modelVersion, listener); + parseInferenceDocLenientlyFromSource(source, modelId, 
listener); }, listener::onFailure)); } @@ -121,14 +120,13 @@ public class TrainedModelProvider { private void parseInferenceDocLenientlyFromSource(BytesReference source, String modelId, - long modelVersion, ActionListener modelListener) { try (InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build()); } catch (Exception e) { - logger.error(new ParameterizedMessage("[{}][{}] failed to parse model", modelId, modelVersion), e); + logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); modelListener.onFailure(e); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java index 26f9353639c..716000472ac 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java @@ -12,14 +12,15 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.env.Environment; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.process.IndexingStateProcessor; import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams; import org.elasticsearch.xpack.ml.job.results.AutodetectResult; +import org.elasticsearch.xpack.ml.process.IndexingStateProcessor; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.ProcessPipes; import org.elasticsearch.xpack.ml.process.ProcessResultsParser; @@ -78,7 +79,8 @@ public class NativeAutodetectProcessFactory implements AutodetectProcessFactory int numberOfFields = job.allInputFields().size() + (includeTokensField ? 
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java
index 26f9353639c..716000472ac 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java
@@ -12,14 +12,15 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
-import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
 import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
+import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;
 import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
@@ -78,7 +79,8 @@ public class NativeAutodetectProcessFactory implements AutodetectProcessFactory
         int numberOfFields = job.allInputFields().size() + (includeTokensField ? 1 : 0) + 1;
 
         IndexingStateProcessor stateProcessor = new IndexingStateProcessor(client, job.getId());
-        ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER);
+        ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER,
+            NamedXContentRegistry.EMPTY);
         NativeAutodetectProcess autodetect = new NativeAutodetectProcess(
             job.getId(), processPipes.getLogStream().get(), processPipes.getProcessInStream().get(),
             processPipes.getProcessOutStream().get(), processPipes.getRestoreStream().orElse(null), numberOfFields,
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java
index 609c45659dd..175fc180512 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java
@@ -32,15 +32,17 @@ public class ProcessResultsParser<T> {
     private static final Logger logger = LogManager.getLogger(ProcessResultsParser.class);
 
     private final ConstructingObjectParser<T, Void> resultParser;
+    private final NamedXContentRegistry namedXContentRegistry;
 
-    public ProcessResultsParser(ConstructingObjectParser<T, Void> resultParser) {
+    public ProcessResultsParser(ConstructingObjectParser<T, Void> resultParser, NamedXContentRegistry namedXContentRegistry) {
         this.resultParser = Objects.requireNonNull(resultParser);
+        this.namedXContentRegistry = Objects.requireNonNull(namedXContentRegistry);
     }
 
     public Iterator<T> parseResults(InputStream in) throws ElasticsearchParseException {
         try {
             XContentParser parser = XContentFactory.xContent(XContentType.JSON)
-                .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, in);
+                .createParser(namedXContentRegistry, LoggingDeprecationHandler.INSTANCE, in);
             XContentParser.Token token = parser.nextToken();
             // if start of an array ignore it, we expect an array of results
             if (token != XContentParser.Token.START_ARRAY) {
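Note on ProcessResultsParser: the parser now carries the NamedXContentRegistry it parses with instead of always using NamedXContentRegistry.EMPTY, so callers whose result stream contains named XContent objects (such as the inference model definition) can supply a richer registry. A minimal usage sketch; `resultsStream` is an assumed stand-in for the native process output stream, not part of this patch:

    // Sketch only: `resultsStream` (an InputStream) is assumed.
    ProcessResultsParser<AutodetectResult> parser =
        new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY);
    Iterator<AutodetectResult> results = parser.parseResults(resultsStream);
    results.forEachRemaining(result -> {
        // handle each parsed result here
    });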
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java
index 60bfd0eb3e3..8612f263a0c 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java
@@ -5,21 +5,42 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.json.JsonXContent;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 import org.junit.Before;
+import org.mockito.ArgumentCaptor;
 import org.mockito.InOrder;
 import org.mockito.Mockito;
 
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasKey;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -28,16 +49,29 @@ import static org.mockito.Mockito.when;
 
 public class AnalyticsResultProcessorTests extends ESTestCase {
 
     private static final String JOB_ID = "analytics-result-processor-tests";
+    private static final String JOB_DESCRIPTION = "This describes the job of these tests";
 
     private AnalyticsProcess<AnalyticsResult> process;
     private DataFrameRowsJoiner dataFrameRowsJoiner;
     private ProgressTracker progressTracker = new ProgressTracker();
+    private TrainedModelProvider trainedModelProvider;
+    private DataFrameAnalyticsAuditor auditor;
+    private DataFrameAnalyticsConfig analyticsConfig;
 
     @Before
     @SuppressWarnings("unchecked")
     public void setUpMocks() {
         process = mock(AnalyticsProcess.class);
         dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class);
+        trainedModelProvider = mock(TrainedModelProvider.class);
+        auditor = mock(DataFrameAnalyticsAuditor.class);
+        analyticsConfig = new DataFrameAnalyticsConfig.Builder()
+            .setId(JOB_ID)
+            .setDescription(JOB_DESCRIPTION)
+            .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null))
+            .setDest(new DataFrameAnalyticsDest("my_dest", null))
+            .setAnalysis(new Regression("foo"))
+            .build();
     }
 
     public void testProcess_GivenNoResults() {
@@ -54,7 +88,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
 
     public void testProcess_GivenEmptyResults() {
         givenDataFrameRows(2);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50), new AnalyticsResult(null, 100)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50, null), new AnalyticsResult(null, 100, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);
@@ -69,7 +103,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         givenDataFrameRows(2);
         RowResults rowResults1 = mock(RowResults.class);
         RowResults rowResults2 = mock(RowResults.class);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50), new AnalyticsResult(rowResults2, 100)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null), new AnalyticsResult(rowResults2, 100, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);
@@ -82,6 +116,71 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
     }
 
+    @SuppressWarnings("unchecked")
+    public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
+        givenDataFrameRows(0);
+
+        doAnswer(invocationOnMock -> {
+            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
+            storeListener.onResponse(true);
+            return null;
+        }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
+
+        TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
+        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
+        AnalyticsResultProcessor resultProcessor = createResultProcessor();
+
+        resultProcessor.process(process);
+        resultProcessor.awaitForCompletion();
+
+        ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
+        verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));
+
+        TrainedModelConfig storedModel = storedModelCaptor.getValue();
+        assertThat(storedModel.getModelId(), containsString(JOB_ID));
+        assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
+        assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));
+        assertThat(storedModel.getTags(), contains(JOB_ID));
+        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
+        assertThat(storedModel.getDefinition(), equalTo(inferenceModel));
+        Map<String, Object> metadata = storedModel.getMetadata();
+        assertThat(metadata.size(), equalTo(1));
+        assertThat(metadata, hasKey("analytics_config"));
+        Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
+            true);
+        assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));
+
+        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
+        verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
+        assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
+        Mockito.verifyNoMoreInteractions(auditor);
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testProcess_GivenInferenceModelFailedToStore() {
+        givenDataFrameRows(0);
+
+        doAnswer(invocationOnMock -> {
+            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
+            storeListener.onFailure(new RuntimeException("some failure"));
+            return null;
+        }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
+
+        TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
+        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
+        AnalyticsResultProcessor resultProcessor = createResultProcessor();
+
+        resultProcessor.process(process);
+        resultProcessor.awaitForCompletion();
+
+        // This test verifies the processor knows how to handle a failure on storing the model and completes normally
+        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
+        verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
+        assertThat(auditCaptor.getValue(), containsString("Error storing trained model with id [" + JOB_ID));
+        assertThat(auditCaptor.getValue(), containsString("[some failure]"));
+        Mockito.verifyNoMoreInteractions(auditor);
+    }
+
     private void givenProcessResults(List<AnalyticsResult> results) {
         when(process.readAnalyticsResults()).thenReturn(results.iterator());
     }
@@ -93,6 +192,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
     }
 
     private AnalyticsResultProcessor createResultProcessor() {
-        return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, progressTracker);
+        return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider,
+            auditor);
     }
 }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java
index 13ef2ac5024..f2860339e32 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java
@@ -5,24 +5,45 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process.results;
 
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 
 public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResult> {
 
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
     @Override
     protected AnalyticsResult createTestInstance() {
         RowResults rowResults = null;
         Integer progressPercent = null;
+        TrainedModelDefinition inferenceModel = null;
         if (randomBoolean()) {
             rowResults = RowResultsTests.createRandom();
         }
         if (randomBoolean()) {
             progressPercent = randomIntBetween(0, 100);
         }
-        return new AnalyticsResult(rowResults, progressPercent);
+        if (randomBoolean()) {
+            inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
+        }
+        return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
     }
 
     @Override
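Note on parsing analytics results that may embed a trained model: the test above builds its NamedXContentRegistry from the ML inference named XContent entries plus the search module entries. A sketch of assembling that registry outside the test, for any parser that must understand the embedded TrainedModelDefinition (variable names are illustrative):

    // Sketch only: mirrors the registry assembled in AnalyticsResultTests.xContentRegistry().
    List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
    namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
    NamedXContentRegistry registry = new NamedXContentRegistry(namedXContent);
    // `registry` would then be handed to a ProcessResultsParser instead of NamedXContentRegistry.EMPTY.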
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
index ae7e1e4dd8d..73204791ed8 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
@@ -39,7 +39,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
 
     public void testPutTrainedModelConfig() throws Exception {
         String modelId = "test-put-trained-model-config";
-        TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
+        TrainedModelConfig config = buildTrainedModelConfig(modelId);
         AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
@@ -50,7 +50,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
 
     public void testPutTrainedModelConfigThatAlreadyExists() throws Exception {
         String modelId = "test-put-trained-model-config-exists";
-        TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
+        TrainedModelConfig config = buildTrainedModelConfig(modelId);
         AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
@@ -61,12 +61,12 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(),
-            equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId, 0)));
+            equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
     }
 
     public void testGetTrainedModelConfig() throws Exception {
         String modelId = "test-get-trained-model-config";
-        TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
+        TrainedModelConfig config = buildTrainedModelConfig(modelId);
         AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
@@ -75,7 +75,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         assertThat(exceptionHolder.get(), is(nullValue()));
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
+        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
         assertThat(getConfigHolder.get(), is(not(nullValue())));
         assertThat(getConfigHolder.get(), equalTo(config));
     }
@@ -84,21 +84,20 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         String modelId = "test-get-missing-trained-model-config";
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
+        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(),
-            equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, 0)));
+            equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
     }
 
-    private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
+    private static TrainedModelConfig buildTrainedModelConfig(String modelId) {
         return TrainedModelConfig.builder()
             .setCreatedBy("ml_test")
             .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
             .setDescription("trained model config for test")
             .setModelId(modelId)
-            .setModelType("binary_decision_tree")
-            .setModelVersion(modelVersion)
-            .build(Version.CURRENT);
+            .setVersion(Version.CURRENT)
+            .build();
     }
 
     @Override
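Note on the TrainedModelConfig builder change visible above: setModelType and setModelVersion are gone, the Elasticsearch version is set through setVersion, and build() no longer takes a Version argument. A before/after sketch for call sites; the model ID is illustrative only:

    // Before this patch (no longer compiles):
    // TrainedModelConfig config = TrainedModelConfig.builder()
    //     .setModelId("my-model")
    //     .setModelType("binary_decision_tree")
    //     .setModelVersion(0L)
    //     .build(Version.CURRENT);

    // After this patch:
    TrainedModelConfig config = TrainedModelConfig.builder()
        .setModelId("my-model")
        .setCreatedBy("ml_test")
        .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
        .setDescription("trained model config for test")
        .setVersion(Version.CURRENT)
        .build();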
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java
index 9cd6343eab7..07a191a4223 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java
@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.job.process.autodetect;
 
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
 import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
@@ -61,7 +62,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
 
         try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
                 mock(OutputStream.class), outputStream, mock(OutputStream.class), NUMBER_FIELDS, null,
-                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
+                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
             process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
 
             ZonedDateTime startTime = process.getProcessStartTime();
@@ -84,7 +85,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
         ByteArrayOutputStream bos = new ByteArrayOutputStream(1024);
         try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, bos,
                 outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
-                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
+                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
             process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
 
             process.writeRecord(record);
@@ -119,7 +120,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
         ByteArrayOutputStream bos = new ByteArrayOutputStream(AutodetectControlMsgWriter.FLUSH_SPACES_LENGTH + 1024);
         try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, bos,
                 outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
-                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
+                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
             process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
 
             FlushJobParams params = FlushJobParams.builder().build();
@@ -153,7 +154,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
 
         try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, processInStream,
                 processOutStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
-                new ProcessResultsParser<AutodetectResult>(AutodetectResult.PARSER), mock(Consumer.class))) {
+                new ProcessResultsParser<AutodetectResult>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
             process.consumeAndCloseOutputStream();
 
             assertThat(processOutStream.available(), equalTo(0));
@@ -169,7 +170,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
         ByteArrayOutputStream bos = new ByteArrayOutputStream(1024);
         try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, bos,
                 outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
-                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
+                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
             process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));
 
             writeFunction.accept(process);
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java
index 32ab15a2701..dff432ab938 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java
@@ -9,6 +9,7 @@ import com.google.common.base.Charsets;
 import org.elasticsearch.ElasticsearchParseException;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.test.ESTestCase;
 
@@ -28,7 +29,7 @@ public class ProcessResultsParserTests extends ESTestCase {
     public void testParse_GivenEmptyArray() throws IOException {
         String json = "[]";
         try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
-            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
+            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
             assertFalse(parser.parseResults(inputStream).hasNext());
         }
     }
@@ -36,7 +37,7 @@ public class ProcessResultsParserTests extends ESTestCase {
     public void testParse_GivenUnknownObject() throws IOException {
         String json = "[{\"unknown\":{\"id\": 18}}]";
         try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
-            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
+            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
             XContentParseException e = expectThrows(XContentParseException.class,
                 () -> parser.parseResults(inputStream).forEachRemaining(a -> {
                 }));
@@ -47,7 +48,7 @@ public class ProcessResultsParserTests extends ESTestCase {
     public void testParse_GivenArrayContainsAnotherArray() throws IOException {
         String json = "[[]]";
         try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
-            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
+            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
             ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class,
                 () -> parser.parseResults(inputStream).forEachRemaining(a -> {
                 }));
@@ -60,7 +61,7 @@
                 + " {\"field_1\": \"c\", \"field_2\": 3.0}]";
 
         try (InputStream inputStream = new ByteArrayInputStream(input.getBytes(Charsets.UTF_8))) {
-            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
+            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
             Iterator<TestResult> testResultIterator = parser.parseResults(inputStream);
 
             List<TestResult> parsedResults = new ArrayList<>();