This adds parsing of an inference model as a possible result of the analytics process. When such a model is parsed, we persist a `TrainedModelConfig` into the inference index; the stored config carries additional metadata derived from the running job.
parent 74812f78dd
commit 0dddbb5b42
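For orientation before the diff: the sketch below condenses how `AnalyticsResultProcessor#createTrainedModelConfig` (added further down) derives the persisted config from the running job. It is a minimal illustration, not the full change — the metadata map (which embeds the analytics config) and the store/listener wiring are omitted, and the enclosing `Example` class is just a hypothetical holder.

```java
import java.time.Instant;
import java.util.Collections;

import org.elasticsearch.Version;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;

class Example {
    // Mirrors AnalyticsResultProcessor#createTrainedModelConfig in this commit:
    // model id, tags, and description are all derived from the analytics job.
    static TrainedModelConfig fromJob(DataFrameAnalyticsConfig analytics, TrainedModelDefinition definition) {
        Instant createTime = Instant.now();
        return TrainedModelConfig.builder()
            .setModelId(analytics.getId() + "-" + createTime.toEpochMilli()) // unique per run
            .setCreatedBy("data-frame-analytics")
            .setVersion(Version.CURRENT)
            .setCreateTime(createTime)
            .setTags(Collections.singletonList(analytics.getId()))          // job id doubles as a tag
            .setDescription(analytics.getDescription())
            .setDefinition(definition)
            .build();
    }
}
```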
@@ -30,21 +30,21 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import java.io.IOException;
 import java.time.Instant;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
 public class TrainedModelConfig implements ToXContentObject {
 
-    public static final String NAME = "trained_model_doc";
+    public static final String NAME = "trained_model_config";
 
     public static final ParseField MODEL_ID = new ParseField("model_id");
     public static final ParseField CREATED_BY = new ParseField("created_by");
     public static final ParseField VERSION = new ParseField("version");
     public static final ParseField DESCRIPTION = new ParseField("description");
-    public static final ParseField CREATED_TIME = new ParseField("created_time");
-    public static final ParseField MODEL_VERSION = new ParseField("model_version");
+    public static final ParseField CREATE_TIME = new ParseField("create_time");
     public static final ParseField DEFINITION = new ParseField("definition");
-    public static final ParseField MODEL_TYPE = new ParseField("model_type");
+    public static final ParseField TAGS = new ParseField("tags");
     public static final ParseField METADATA = new ParseField("metadata");
 
     public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
@@ -55,16 +55,15 @@ public class TrainedModelConfig implements ToXContentObject {
         PARSER.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
         PARSER.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
         PARSER.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
-        PARSER.declareField(TrainedModelConfig.Builder::setCreatedTime,
-            (p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()),
-            CREATED_TIME,
+        PARSER.declareField(TrainedModelConfig.Builder::setCreateTime,
+            (p, c) -> TimeUtil.parseTimeFieldToInstant(p, CREATE_TIME.getPreferredName()),
+            CREATE_TIME,
             ObjectParser.ValueType.VALUE);
-        PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
-        PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
-        PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
         PARSER.declareObject(TrainedModelConfig.Builder::setDefinition,
             (p, c) -> TrainedModelDefinition.fromXContent(p),
             DEFINITION);
+        PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
+        PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
     }
 
     public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@@ -75,30 +74,27 @@ public class TrainedModelConfig implements ToXContentObject {
     private final String createdBy;
     private final Version version;
     private final String description;
-    private final Instant createdTime;
-    private final Long modelVersion;
-    private final String modelType;
-    private final Map<String, Object> metadata;
+    private final Instant createTime;
     private final TrainedModelDefinition definition;
+    private final List<String> tags;
+    private final Map<String, Object> metadata;
 
     TrainedModelConfig(String modelId,
                        String createdBy,
                        Version version,
                        String description,
-                       Instant createdTime,
-                       Long modelVersion,
-                       String modelType,
+                       Instant createTime,
                        TrainedModelDefinition definition,
+                       List<String> tags,
                        Map<String, Object> metadata) {
         this.modelId = modelId;
         this.createdBy = createdBy;
         this.version = version;
-        this.createdTime = Instant.ofEpochMilli(createdTime.toEpochMilli());
-        this.modelType = modelType;
+        this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli());
         this.definition = definition;
         this.description = description;
+        this.tags = tags == null ? null : Collections.unmodifiableList(tags);
         this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
-        this.modelVersion = modelVersion;
     }
 
     public String getModelId() {
@@ -117,16 +113,12 @@ public class TrainedModelConfig implements ToXContentObject {
         return description;
     }
 
-    public Instant getCreatedTime() {
-        return createdTime;
+    public Instant getCreateTime() {
+        return createTime;
     }
 
-    public Long getModelVersion() {
-        return modelVersion;
-    }
-
-    public String getModelType() {
-        return modelType;
+    public List<String> getTags() {
+        return tags;
     }
 
     public Map<String, Object> getMetadata() {
@@ -156,18 +148,15 @@ public class TrainedModelConfig implements ToXContentObject {
         if (description != null) {
             builder.field(DESCRIPTION.getPreferredName(), description);
         }
-        if (createdTime != null) {
-            builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli());
-        }
-        if (modelVersion != null) {
-            builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
-        }
-        if (modelType != null) {
-            builder.field(MODEL_TYPE.getPreferredName(), modelType);
+        if (createTime != null) {
+            builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
         }
         if (definition != null) {
             builder.field(DEFINITION.getPreferredName(), definition);
         }
+        if (tags != null) {
+            builder.field(TAGS.getPreferredName(), tags);
+        }
         if (metadata != null) {
             builder.field(METADATA.getPreferredName(), metadata);
         }
@@ -189,10 +178,9 @@ public class TrainedModelConfig implements ToXContentObject {
             Objects.equals(createdBy, that.createdBy) &&
             Objects.equals(version, that.version) &&
             Objects.equals(description, that.description) &&
-            Objects.equals(createdTime, that.createdTime) &&
-            Objects.equals(modelVersion, that.modelVersion) &&
-            Objects.equals(modelType, that.modelType) &&
+            Objects.equals(createTime, that.createTime) &&
             Objects.equals(definition, that.definition) &&
+            Objects.equals(tags, that.tags) &&
             Objects.equals(metadata, that.metadata);
     }
 
@@ -201,12 +189,11 @@ public class TrainedModelConfig implements ToXContentObject {
         return Objects.hash(modelId,
             createdBy,
             version,
-            createdTime,
-            modelType,
+            createTime,
             definition,
             description,
-            metadata,
-            modelVersion);
+            tags,
+            metadata);
     }
 
 
@@ -216,11 +203,10 @@ public class TrainedModelConfig implements ToXContentObject {
         private String createdBy;
         private Version version;
         private String description;
-        private Instant createdTime;
-        private Long modelVersion;
-        private String modelType;
+        private Instant createTime;
         private Map<String, Object> metadata;
-        private TrainedModelDefinition.Builder definition;
+        private List<String> tags;
+        private TrainedModelDefinition definition;
 
         public Builder setModelId(String modelId) {
             this.modelId = modelId;
@@ -246,18 +232,13 @@ public class TrainedModelConfig implements ToXContentObject {
             return this;
         }
 
-        private Builder setCreatedTime(Instant createdTime) {
-            this.createdTime = createdTime;
+        private Builder setCreateTime(Instant createTime) {
+            this.createTime = createTime;
             return this;
         }
 
-        public Builder setModelVersion(Long modelVersion) {
-            this.modelVersion = modelVersion;
-            return this;
-        }
-
-        public Builder setModelType(String modelType) {
-            this.modelType = modelType;
+        public Builder setTags(List<String> tags) {
+            this.tags = tags;
             return this;
         }
 
@@ -267,6 +248,11 @@ public class TrainedModelConfig implements ToXContentObject {
         }
 
         public Builder setDefinition(TrainedModelDefinition.Builder definition) {
+            this.definition = definition == null ? null : definition.build();
+            return this;
+        }
+
+        public Builder setDefinition(TrainedModelDefinition definition) {
             this.definition = definition;
             return this;
         }
@@ -277,10 +263,9 @@ public class TrainedModelConfig implements ToXContentObject {
                 createdBy,
                 version,
                 description,
-                createdTime,
-                modelVersion,
-                modelType,
-                definition == null ? null : definition.build(),
+                createTime,
+                definition,
+                tags,
                 metadata);
         }
     }
@@ -31,6 +31,8 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 
 public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedModelConfig> {
@@ -58,9 +60,9 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
             Version.CURRENT,
             randomBoolean() ? null : randomAlphaOfLength(100),
             Instant.ofEpochMilli(randomNonNegativeLong()),
-            randomBoolean() ? null : randomNonNegativeLong(),
-            randomAlphaOfLength(10),
             randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
+            randomBoolean() ? null :
+                Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
             randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
     }
 
@@ -17,28 +17,30 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.common.time.TimeUtils;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 
 import java.io.IOException;
 import java.time.Instant;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
 public class TrainedModelConfig implements ToXContentObject, Writeable {
 
-    public static final String NAME = "trained_model_doc";
+    public static final String NAME = "trained_model_config";
 
     public static final ParseField MODEL_ID = new ParseField("model_id");
     public static final ParseField CREATED_BY = new ParseField("created_by");
     public static final ParseField VERSION = new ParseField("version");
     public static final ParseField DESCRIPTION = new ParseField("description");
-    public static final ParseField CREATED_TIME = new ParseField("created_time");
-    public static final ParseField MODEL_VERSION = new ParseField("model_version");
+    public static final ParseField CREATE_TIME = new ParseField("create_time");
     public static final ParseField DEFINITION = new ParseField("definition");
-    public static final ParseField MODEL_TYPE = new ParseField("model_type");
+    public static final ParseField TAGS = new ParseField("tags");
     public static final ParseField METADATA = new ParseField("metadata");
 
+    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
@@ -53,16 +55,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         parser.declareString(TrainedModelConfig.Builder::setCreatedBy, CREATED_BY);
         parser.declareString(TrainedModelConfig.Builder::setVersion, VERSION);
         parser.declareString(TrainedModelConfig.Builder::setDescription, DESCRIPTION);
-        parser.declareField(TrainedModelConfig.Builder::setCreatedTime,
-            (p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATED_TIME.getPreferredName()),
-            CREATED_TIME,
+        parser.declareField(TrainedModelConfig.Builder::setCreateTime,
+            (p, c) -> TimeUtils.parseTimeFieldToInstant(p, CREATE_TIME.getPreferredName()),
+            CREATE_TIME,
             ObjectParser.ValueType.VALUE);
-        parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
-        parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
+        parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
         parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
         parser.declareObject(TrainedModelConfig.Builder::setDefinition,
             (p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
             DEFINITION);
+        parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE);
         return parser;
     }
 
@@ -70,41 +72,35 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
     }
 
-    public static String documentId(String modelId, long modelVersion) {
-        return NAME + "-" + modelId + "-" + modelVersion;
-    }
-
 
     private final String modelId;
     private final String createdBy;
     private final Version version;
     private final String description;
-    private final Instant createdTime;
-    private final long modelVersion;
-    private final String modelType;
+    private final Instant createTime;
+    private final List<String> tags;
    private final Map<String, Object> metadata;
 
     // TODO how to reference and store large models that will not be executed in Java???
     // Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
     // TODO Should this be lazily parsed when loading via the index???
     private final TrainedModelDefinition definition;
 
     TrainedModelConfig(String modelId,
                        String createdBy,
                        Version version,
                        String description,
-                       Instant createdTime,
-                       Long modelVersion,
-                       String modelType,
+                       Instant createTime,
                        TrainedModelDefinition definition,
+                       List<String> tags,
                        Map<String, Object> metadata) {
         this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
         this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
         this.version = ExceptionsHelper.requireNonNull(version, VERSION);
-        this.createdTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createdTime, CREATED_TIME).toEpochMilli());
-        this.modelType = ExceptionsHelper.requireNonNull(modelType, MODEL_TYPE);
+        this.createTime = Instant.ofEpochMilli(ExceptionsHelper.requireNonNull(createTime, CREATE_TIME).toEpochMilli());
         this.definition = definition;
         this.description = description;
+        this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
         this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
-        this.modelVersion = modelVersion == null ? 0 : modelVersion;
     }
 
     public TrainedModelConfig(StreamInput in) throws IOException {
@@ -112,10 +108,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         createdBy = in.readString();
         version = Version.readVersion(in);
         description = in.readOptionalString();
-        createdTime = in.readInstant();
-        modelVersion = in.readVLong();
-        modelType = in.readString();
+        createTime = in.readInstant();
         definition = in.readOptionalWriteable(TrainedModelDefinition::new);
+        tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
         metadata = in.readMap();
     }
 
@@ -135,16 +130,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         return description;
     }
 
-    public Instant getCreatedTime() {
-        return createdTime;
+    public Instant getCreateTime() {
+        return createTime;
     }
 
-    public long getModelVersion() {
-        return modelVersion;
-    }
-
-    public String getModelType() {
-        return modelType;
+    public List<String> getTags() {
+        return tags;
     }
 
     public Map<String, Object> getMetadata() {
@@ -166,10 +157,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         out.writeString(createdBy);
         Version.writeVersion(version, out);
         out.writeOptionalString(description);
-        out.writeInstant(createdTime);
-        out.writeVLong(modelVersion);
-        out.writeString(modelType);
+        out.writeInstant(createTime);
         out.writeOptionalWriteable(definition);
+        out.writeCollection(tags, StreamOutput::writeString);
         out.writeMap(metadata);
     }
 
@@ -182,15 +172,17 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         if (description != null) {
             builder.field(DESCRIPTION.getPreferredName(), description);
         }
-        builder.timeField(CREATED_TIME.getPreferredName(), CREATED_TIME.getPreferredName() + "_string", createdTime.toEpochMilli());
-        builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
-        builder.field(MODEL_TYPE.getPreferredName(), modelType);
+        builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
         if (definition != null) {
             builder.field(DEFINITION.getPreferredName(), definition);
         }
+        builder.field(TAGS.getPreferredName(), tags);
         if (metadata != null) {
             builder.field(METADATA.getPreferredName(), metadata);
         }
+        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
+            builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
+        }
         builder.endObject();
         return builder;
     }
@@ -209,10 +201,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             Objects.equals(createdBy, that.createdBy) &&
             Objects.equals(version, that.version) &&
             Objects.equals(description, that.description) &&
-            Objects.equals(createdTime, that.createdTime) &&
-            Objects.equals(modelVersion, that.modelVersion) &&
-            Objects.equals(modelType, that.modelType) &&
+            Objects.equals(createTime, that.createTime) &&
             Objects.equals(definition, that.definition) &&
+            Objects.equals(tags, that.tags) &&
            Objects.equals(metadata, that.metadata);
     }
 
@@ -221,12 +212,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         return Objects.hash(modelId,
             createdBy,
             version,
-            createdTime,
-            modelType,
+            createTime,
             definition,
             description,
-            metadata,
-            modelVersion);
+            tags,
+            metadata);
     }
 
     public static class Builder {
@@ -235,11 +225,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         private String createdBy;
         private Version version;
         private String description;
-        private Instant createdTime;
-        private Long modelVersion;
-        private String modelType;
+        private Instant createTime;
+        private List<String> tags = Collections.emptyList();
         private Map<String, Object> metadata;
-        private TrainedModelDefinition.Builder definition;
+        private TrainedModelDefinition definition;
 
         public Builder setModelId(String modelId) {
             this.modelId = modelId;
@@ -265,18 +254,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return this;
         }
 
-        public Builder setCreatedTime(Instant createdTime) {
-            this.createdTime = createdTime;
+        public Builder setCreateTime(Instant createTime) {
+            this.createTime = createTime;
             return this;
         }
 
-        public Builder setModelVersion(Long modelVersion) {
-            this.modelVersion = modelVersion;
-            return this;
-        }
-
-        public Builder setModelType(String modelType) {
-            this.modelType = modelType;
+        public Builder setTags(List<String> tags) {
+            this.tags = ExceptionsHelper.requireNonNull(tags, TAGS);
             return this;
         }
 
@@ -286,6 +270,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         }
 
         public Builder setDefinition(TrainedModelDefinition.Builder definition) {
+            this.definition = definition.build();
+            return this;
+        }
+
+        public Builder setDefinition(TrainedModelDefinition definition) {
             this.definition = definition;
             return this;
         }
@@ -316,9 +305,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                     CREATED_BY.getPreferredName());
             }
 
-            if (createdTime != null) {
+            if (createTime != null) {
                 throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
-                    CREATED_TIME.getPreferredName());
+                    CREATE_TIME.getPreferredName());
             }
         }
 
@@ -328,23 +317,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                 createdBy,
                 version,
                 description,
-                createdTime,
-                modelVersion,
-                modelType,
-                definition == null ? null : definition.build(),
-                metadata);
-        }
-
-        public TrainedModelConfig build(Version version) {
-            return new TrainedModelConfig(
-                modelId,
-                createdBy,
-                version,
-                description,
-                Instant.now(),
-                modelVersion,
-                modelType,
-                definition == null ? null : definition.build(),
+                createTime == null ? Instant.now() : createTime,
+                definition,
+                tags,
                 metadata);
         }
     }
@@ -5,6 +5,8 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.persistence;
 
+import org.elasticsearch.common.ParseField;
+
 /**
  * Class containing the index constants so that the index version, name, and prefix are available to a wider audience.
  */
@@ -14,6 +16,7 @@ public final class InferenceIndexConstants {
     public static final String INDEX_NAME_PREFIX = ".ml-inference-";
     public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
     public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;
+    public static final ParseField DOC_TYPE = new ParseField("doc_type");
 
     private InferenceIndexConstants() {}
 
@@ -80,11 +80,11 @@ public final class Messages {
     public static final String INVALID_GROUP = "Invalid group id ''{0}''; must be non-empty string and may contain lowercase alphanumeric" +
         " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";
 
-    public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] with version [{1}] already exists";
+    public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
     public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
     public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL =
         "Failed to serialize the trained model [{0}] with version [{1}] for storage";
-    public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}] with version [{1}]";
+    public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
 
     public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
     public static final String JOB_AUDIT_CREATED = "Job created";
@@ -21,6 +21,7 @@ import org.junit.Before;
 import java.io.IOException;
 import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.function.Predicate;
@@ -29,7 +30,6 @@ import java.util.stream.IntStream;
 
 import static org.hamcrest.Matchers.equalTo;
 
 
 public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
 
     private boolean lenient;
@@ -56,15 +56,15 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
 
     @Override
     protected TrainedModelConfig createTestInstance() {
+        List<String> tags = Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false));
         return new TrainedModelConfig(
             randomAlphaOfLength(10),
             randomAlphaOfLength(10),
             Version.CURRENT,
             randomBoolean() ? null : randomAlphaOfLength(100),
             Instant.ofEpochMilli(randomNonNegativeLong()),
-            randomBoolean() ? null : randomNonNegativeLong(),
-            randomAlphaOfLength(10),
             randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
+            tags,
             randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
     }
 
@@ -116,9 +116,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
         ElasticsearchException ex = expectThrows(ElasticsearchException.class,
             () -> TrainedModelConfig.builder()
                 .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
-                .setCreatedTime(Instant.now())
+                .setCreateTime(Instant.now())
                 .setModelId(modelId).validate());
-        assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
+        assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
 
         ex = expectThrows(ElasticsearchException.class,
             () -> TrainedModelConfig.builder()
@@ -200,6 +200,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.job.JobManager;
 import org.elasticsearch.xpack.ml.job.JobManagerHolder;
 import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier;
@@ -523,7 +524,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
                 client,
                 clusterService);
             normalizerProcessFactory = new NativeNormalizerProcessFactory(environment, nativeController, clusterService);
-            analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService);
+            analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService,
+                xContentRegistry);
             memoryEstimationProcessFactory =
                 new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService);
         } catch (IOException e) {
@@ -566,9 +568,12 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
             System::currentTimeMillis, anomalyDetectionAuditor, autodetectProcessManager);
         this.datafeedManager.set(datafeedManager);
 
+        // Inference components
+        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
+
         // Data frame analytics components
         AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
-            dataFrameAnalyticsAuditor);
+            dataFrameAnalyticsAuditor, trainedModelProvider);
         MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
             new MemoryUsageEstimationProcessManager(
                 threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);
@@ -614,7 +619,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
                 analyticsProcessManager,
                 memoryEstimationProcessManager,
                 dataFrameAnalyticsConfigProvider,
-                nativeStorageProvider
+                nativeStorageProvider,
+                trainedModelProvider
             );
         }
 
@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.dataframe.process;
 
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xpack.ml.process.AbstractNativeProcess;
 import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
 
@@ -26,10 +27,11 @@ abstract class AbstractNativeAnalyticsProcess<Result> extends AbstractNativeProc
     protected AbstractNativeAnalyticsProcess(String name, ConstructingObjectParser<Result, Void> resultParser, String jobId,
                                              InputStream logStream, OutputStream processInStream,
                                              InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields,
-                                             List<Path> filesToDelete, Consumer<String> onProcessCrash) {
+                                             List<Path> filesToDelete, Consumer<String> onProcessCrash,
+                                             NamedXContentRegistry namedXContentRegistry) {
         super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash);
         this.name = Objects.requireNonNull(name);
-        this.resultsParser = new ProcessResultsParser<>(Objects.requireNonNull(resultParser));
+        this.resultsParser = new ProcessResultsParser<>(Objects.requireNonNull(resultParser), namedXContentRegistry);
     }
 
     @Override
@@ -34,6 +34,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
 import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 
 import java.io.IOException;
@@ -57,15 +58,18 @@ public class AnalyticsProcessManager {
     private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
     private final DataFrameAnalyticsAuditor auditor;
+    private final TrainedModelProvider trainedModelProvider;
 
     public AnalyticsProcessManager(Client client,
                                    ThreadPool threadPool,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
-                                   DataFrameAnalyticsAuditor auditor) {
+                                   DataFrameAnalyticsAuditor auditor,
+                                   TrainedModelProvider trainedModelProvider) {
         this.client = Objects.requireNonNull(client);
         this.threadPool = Objects.requireNonNull(threadPool);
         this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
         this.auditor = Objects.requireNonNull(auditor);
+        this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
     }
 
     public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
@@ -356,7 +360,8 @@ public class AnalyticsProcessManager {
             process = createProcess(task, config, analyticsProcessConfig, state);
             DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
                 dataExtractorFactory.newExtractor(true));
-            resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker());
+            resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
+                trainedModelProvider, auditor);
             return true;
         }
 
@@ -8,33 +8,51 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.LatchedActionListener;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.json.JsonXContent;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 
+import java.time.Instant;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.Objects;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Supplier;
 
 public class AnalyticsResultProcessor {
 
     private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);
 
-    private final String dataFrameAnalyticsId;
+    private final DataFrameAnalyticsConfig analytics;
     private final DataFrameRowsJoiner dataFrameRowsJoiner;
     private final Supplier<Boolean> isProcessKilled;
     private final ProgressTracker progressTracker;
+    private final TrainedModelProvider trainedModelProvider;
+    private final DataFrameAnalyticsAuditor auditor;
     private final CountDownLatch completionLatch = new CountDownLatch(1);
     private volatile String failure;
 
-    public AnalyticsResultProcessor(String dataFrameAnalyticsId, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier<Boolean> isProcessKilled,
-                                    ProgressTracker progressTracker) {
-        this.dataFrameAnalyticsId = Objects.requireNonNull(dataFrameAnalyticsId);
+    public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
+                                    Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
+                                    TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor) {
+        this.analytics = Objects.requireNonNull(analytics);
         this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
         this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
         this.progressTracker = Objects.requireNonNull(progressTracker);
+        this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
+        this.auditor = Objects.requireNonNull(auditor);
     }
 
     @Nullable
@@ -47,7 +65,7 @@ public class AnalyticsResultProcessor {
             completionLatch.await();
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
-            LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", dataFrameAnalyticsId), e);
+            LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", analytics.getId()), e);
         }
     }
 
@@ -75,7 +93,7 @@ public class AnalyticsResultProcessor {
             if (isProcessKilled.get()) {
                 // No need to log error as it's due to stopping
             } else {
-                LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", dataFrameAnalyticsId), e);
+                LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
                 failure = "error parsing data frame analytics output: [" + e.getMessage() + "]";
             }
         } finally {
@@ -93,5 +111,60 @@ public class AnalyticsResultProcessor {
         if (progressPercent != null) {
             progressTracker.analyzingPercent.set(progressPercent);
         }
+        TrainedModelDefinition inferenceModel = result.getInferenceModel();
+        if (inferenceModel != null) {
+            createAndIndexInferenceModel(inferenceModel);
+        }
     }
+
+    private void createAndIndexInferenceModel(TrainedModelDefinition inferenceModel) {
+        TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
+        CountDownLatch latch = storeTrainedModel(trainedModelConfig);
+
+        try {
+            if (latch.await(30, TimeUnit.SECONDS) == false) {
+                LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId());
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for inference model to be stored", analytics.getId()), e);
+        }
+    }
+
+    private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition inferenceModel) {
+        Instant createTime = Instant.now();
+        return TrainedModelConfig.builder()
+            .setModelId(analytics.getId() + "-" + createTime.toEpochMilli())
+            .setCreatedBy("data-frame-analytics")
+            .setVersion(Version.CURRENT)
+            .setCreateTime(createTime)
+            .setTags(Collections.singletonList(analytics.getId()))
+            .setDescription(analytics.getDescription())
+            .setMetadata(Collections.singletonMap("analytics_config",
+                XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
+            .setDefinition(inferenceModel)
+            .build();
+    }
+
+    private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
+        CountDownLatch latch = new CountDownLatch(1);
+        ActionListener<Boolean> storeListener = ActionListener.wrap(
+            aBoolean -> {
+                if (aBoolean == false) {
+                    LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
+                } else {
+                    LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
+                    auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
+                }
+            },
+            e -> {
+                LOGGER.error(new ParameterizedMessage("[{}] Error storing trained model [{}]", analytics.getId(),
+                    trainedModelConfig.getModelId()), e);
+                auditor.error(analytics.getId(), "Error storing trained model with id [" + trainedModelConfig.getModelId()
+                    + "]; error message [" + e.getMessage() + "]");
+            }
+        );
+        trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
+        return latch;
+    }
 }
@@ -6,15 +6,14 @@
 package org.elasticsearch.xpack.ml.dataframe.process;
 
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
-import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
 import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper;
 
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.nio.file.Path;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Objects;
 import java.util.function.Consumer;
@@ -23,14 +22,14 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<Analy
 
     private static final String NAME = "analytics";
 
-    private final ProcessResultsParser<AnalyticsResult> resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER);
     private final AnalyticsProcessConfig config;
 
     protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream,
                                      OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
-                                     Consumer<String> onProcessCrash, AnalyticsProcessConfig config) {
+                                     Consumer<String> onProcessCrash, AnalyticsProcessConfig config,
+                                     NamedXContentRegistry namedXContentRegistry) {
         super(NAME, AnalyticsResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields,
-            filesToDelete, onProcessCrash);
+            filesToDelete, onProcessCrash, namedXContentRegistry);
         this.config = Objects.requireNonNull(config);
     }
 
@@ -49,11 +48,6 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<Analy
         new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData();
     }
 
-    @Override
-    public Iterator<AnalyticsResult> readAnalyticsResults() {
-        return resultsParser.parseResults(processOutStream());
-    }
-
     @Override
     public AnalyticsProcessConfig getConfig() {
         return config;
@@ -13,6 +13,7 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -42,12 +43,15 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
     private final Client client;
     private final Environment env;
     private final NativeController nativeController;
+    private final NamedXContentRegistry namedXContentRegistry;
     private volatile Duration processConnectTimeout;
 
-    public NativeAnalyticsProcessFactory(Environment env, Client client, NativeController nativeController, ClusterService clusterService) {
+    public NativeAnalyticsProcessFactory(Environment env, Client client, NativeController nativeController, ClusterService clusterService,
+                                         NamedXContentRegistry namedXContentRegistry) {
         this.env = Objects.requireNonNull(env);
         this.client = Objects.requireNonNull(client);
         this.nativeController = Objects.requireNonNull(nativeController);
+        this.namedXContentRegistry = Objects.requireNonNull(namedXContentRegistry);
         setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings()));
         clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT,
             this::setProcessConnectTimeout);
@@ -73,7 +77,8 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
 
         NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, processPipes.getLogStream().get(),
             processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(),
-            processPipes.getRestoreStream().orElse(null), numberOfFields, filesToDelete, onProcessCrash, analyticsProcessConfig);
+            processPipes.getRestoreStream().orElse(null), numberOfFields, filesToDelete, onProcessCrash, analyticsProcessConfig,
+            namedXContentRegistry);
 
         try {
             startProcess(config, executorService, processPipes, analyticsProcess);
@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.dataframe.process;
 
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 
 import java.io.InputStream;
@@ -23,7 +24,7 @@ public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsP
                                                 OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
                                                 Consumer<String> onProcessCrash) {
         super(NAME, MemoryUsageEstimationResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream,
-            numberOfFields, filesToDelete, onProcessCrash);
+            numberOfFields, filesToDelete, onProcessCrash, NamedXContentRegistry.EMPTY);
     }
 
     @Override
@@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -20,21 +21,26 @@ public class AnalyticsResult implements ToXContentObject {
     public static final ParseField TYPE = new ParseField("analytics_result");
 
     public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
+    public static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
 
     public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
-            a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1]));
+            a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition) a[2]));
 
     static {
         PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
         PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
+        PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.STRICT_PARSER.apply(p, null).build(),
+            INFERENCE_MODEL);
     }
 
     private final RowResults rowResults;
     private final Integer progressPercent;
+    private final TrainedModelDefinition inferenceModel;
 
-    public AnalyticsResult(RowResults rowResults, Integer progressPercent) {
+    public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition inferenceModel) {
         this.rowResults = rowResults;
         this.progressPercent = progressPercent;
+        this.inferenceModel = inferenceModel;
     }
 
     public RowResults getRowResults() {
@@ -45,6 +51,10 @@ public class AnalyticsResult implements ToXContentObject {
         return progressPercent;
     }
 
+    public TrainedModelDefinition getInferenceModel() {
+        return inferenceModel;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -54,6 +64,9 @@ public class AnalyticsResult implements ToXContentObject {
         if (progressPercent != null) {
             builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent);
         }
+        if (inferenceModel != null) {
+            builder.field(INFERENCE_MODEL.getPreferredName(), inferenceModel);
+        }
         builder.endObject();
         return builder;
     }
@@ -68,11 +81,13 @@ public class AnalyticsResult implements ToXContentObject {
         }
 
         AnalyticsResult that = (AnalyticsResult) other;
-        return Objects.equals(rowResults, that.rowResults) && Objects.equals(progressPercent, that.progressPercent);
+        return Objects.equals(rowResults, that.rowResults)
+            && Objects.equals(progressPercent, that.progressPercent)
+            && Objects.equals(inferenceModel, that.inferenceModel);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(rowResults, progressPercent);
+        return Objects.hash(rowResults, progressPercent, inferenceModel);
     }
 }
@@ -12,6 +12,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -23,13 +24,11 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappi
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC;
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED;
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD;
-import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG;
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES;
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT;
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE;
 import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.addMetaInformation;
 
 
 /**
  * Changelog of internal index versions
  *
@@ -68,6 +67,12 @@ public final class InferenceInternalIndex {
         builder.field(DYNAMIC, "false");
 
         builder.startObject(PROPERTIES);
+
+        // Add the doc_type field
+        builder.startObject(InferenceIndexConstants.DOC_TYPE.getPreferredName())
+            .field(TYPE, KEYWORD)
+            .endObject();
+
         addInferenceDocFields(builder);
         return builder.endObject()
             .endObject()
@@ -87,16 +92,13 @@ public final class InferenceInternalIndex {
             .startObject(TrainedModelConfig.DESCRIPTION.getPreferredName())
                 .field(TYPE, TEXT)
             .endObject()
-            .startObject(TrainedModelConfig.CREATED_TIME.getPreferredName())
+            .startObject(TrainedModelConfig.CREATE_TIME.getPreferredName())
                 .field(TYPE, DATE)
            .endObject()
-            .startObject(TrainedModelConfig.MODEL_VERSION.getPreferredName())
-                .field(TYPE, LONG)
-            .endObject()
             .startObject(TrainedModelConfig.DEFINITION.getPreferredName())
                 .field(ENABLED, false)
             .endObject()
-            .startObject(TrainedModelConfig.MODEL_TYPE.getPreferredName())
+            .startObject(TrainedModelConfig.TAGS.getPreferredName())
                .field(TYPE, KEYWORD)
             .endObject()
             .startObject(TrainedModelConfig.METADATA.getPreferredName())
@ -37,9 +37,11 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
||||
|
@ -57,26 +59,23 @@ public class TrainedModelProvider {
|
|||
|
||||
public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) {
|
||||
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
XContentBuilder source = trainedModelConfig.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
XContentBuilder source = trainedModelConfig.toXContent(builder,
|
||||
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME)
|
||||
.opType(DocWriteRequest.OpType.CREATE)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.id(TrainedModelConfig.documentId(trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion()))
|
||||
.id(trainedModelConfig.getModelId())
|
||||
.source(source);
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest,
|
||||
ActionListener.wrap(
|
||||
r -> listener.onResponse(true),
|
||||
e -> {
|
||||
logger.error(
|
||||
new ParameterizedMessage("[{}][{}] failed to store trained model for inference",
|
||||
trainedModelConfig.getModelId(),
|
||||
trainedModelConfig.getModelVersion()),
|
||||
e);
|
||||
logger.error(new ParameterizedMessage(
|
||||
"[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e);
|
||||
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
|
||||
listener.onFailure(new ResourceAlreadyExistsException(
|
||||
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS,
|
||||
trainedModelConfig.getModelId(), trainedModelConfig.getModelVersion())));
|
||||
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
|
||||
} else {
|
||||
listener.onFailure(
|
||||
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
|
||||
|
@ -93,10 +92,10 @@ public class TrainedModelProvider {
|
|||
}
|
||||
}
|
||||
|
||||
public void getTrainedModel(String modelId, long modelVersion, ActionListener<TrainedModelConfig> listener) {
|
||||
public void getTrainedModel(String modelId, ActionListener<TrainedModelConfig> listener) {
|
||||
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
||||
.idsQuery()
|
||||
.addIds(TrainedModelConfig.documentId(modelId, modelVersion)));
|
||||
.addIds(modelId));
|
||||
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||
.setQuery(queryBuilder)
|
||||
// use sort to get the last
|
||||
|
@ -109,11 +108,11 @@ public class TrainedModelProvider {
|
|||
searchResponse -> {
|
||||
if (searchResponse.getHits().getHits().length == 0) {
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion)));
|
||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||
return;
|
||||
}
|
||||
BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef();
|
||||
parseInferenceDocLenientlyFromSource(source, modelId, modelVersion, listener);
|
||||
parseInferenceDocLenientlyFromSource(source, modelId, listener);
|
||||
},
|
||||
listener::onFailure));
|
||||
}
|
||||
|
@@ -121,14 +120,13 @@ public class TrainedModelProvider {

private void parseInferenceDocLenientlyFromSource(BytesReference source,
                                                  String modelId,
                                                  long modelVersion,
                                                  ActionListener<TrainedModelConfig> modelListener) {
    try (InputStream stream = source.streamInput();
         XContentParser parser = XContentFactory.xContent(XContentType.JSON)
             .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
        modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build());
    } catch (Exception e) {
        logger.error(new ParameterizedMessage("[{}][{}] failed to parse model", modelId, modelVersion), e);
        logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e);
        modelListener.onFailure(e);
    }
}

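Taken together, a retrieval round trip is now keyed purely on the model id: the ids query finds the document, and the hit's source is parsed leniently into a TrainedModelConfig. A minimal sketch of the caller side (the `fetch` helper is hypothetical; the types are those used in the diff above):

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

class GetTrainedModelExample {
    static void fetch(TrainedModelProvider provider, String modelId) {
        provider.getTrainedModel(modelId, ActionListener.wrap(
            // the listener receives the config parsed leniently from the hit's source
            config -> System.out.println("fetched " + config.getModelId()),
            e -> {
                // an empty hit set is reported as ResourceNotFoundException
                if (e instanceof ResourceNotFoundException) {
                    System.out.println(modelId + " does not exist");
                }
            }));
    }
}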
@@ -12,14 +12,15 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.Environment;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;

@@ -78,7 +79,8 @@ public class NativeAutodetectProcessFactory implements AutodetectProcessFactory
    int numberOfFields = job.allInputFields().size() + (includeTokensField ? 1 : 0) + 1;

    IndexingStateProcessor stateProcessor = new IndexingStateProcessor(client, job.getId());
    ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER);
    ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER,
        NamedXContentRegistry.EMPTY);
    NativeAutodetectProcess autodetect = new NativeAutodetectProcess(
        job.getId(), processPipes.getLogStream().get(), processPipes.getProcessInStream().get(),
        processPipes.getProcessOutStream().get(), processPipes.getRestoreStream().orElse(null), numberOfFields,

@@ -32,15 +32,17 @@ public class ProcessResultsParser<T> {
    private static final Logger logger = LogManager.getLogger(ProcessResultsParser.class);

    private final ConstructingObjectParser<T, Void> resultParser;
    private final NamedXContentRegistry namedXContentRegistry;

    public ProcessResultsParser(ConstructingObjectParser<T, Void> resultParser) {
    public ProcessResultsParser(ConstructingObjectParser<T, Void> resultParser, NamedXContentRegistry namedXContentRegistry) {
        this.resultParser = Objects.requireNonNull(resultParser);
        this.namedXContentRegistry = Objects.requireNonNull(namedXContentRegistry);
    }

    public Iterator<T> parseResults(InputStream in) throws ElasticsearchParseException {
        try {
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, in);
                .createParser(namedXContentRegistry, LoggingDeprecationHandler.INSTANCE, in);
            XContentParser.Token token = parser.nextToken();
            // if start of an array ignore it, we expect an array of results
            if (token != XContentParser.Token.START_ARRAY) {

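The extra constructor argument lets callers thread a NamedXContentRegistry down to the stream parser; analytics results that embed a trained model definition need the ML inference named-xcontent entries, while autodetect results can keep using the empty registry. A minimal usage sketch, assuming `processOutput` is the native process's output stream (the `parse` wrapper is hypothetical):

import java.io.InputStream;
import java.util.Iterator;

import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;

class ParseResultsExample {
    static void parse(InputStream processOutput) {
        // The empty registry suffices here because AutodetectResult has no named objects.
        ProcessResultsParser<AutodetectResult> parser =
            new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY);
        Iterator<AutodetectResult> results = parser.parseResults(processOutput);
        while (results.hasNext()) {
            AutodetectResult result = results.next();
            // handle each result as it streams out of the native process
        }
    }
}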
@@ -5,21 +5,42 @@
 */
package org.elasticsearch.xpack.ml.dataframe.process;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mockito;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -28,16 +49,29 @@ import static org.mockito.Mockito.when;
public class AnalyticsResultProcessorTests extends ESTestCase {

    private static final String JOB_ID = "analytics-result-processor-tests";
    private static final String JOB_DESCRIPTION = "This describes the job of these tests";

    private AnalyticsProcess<AnalyticsResult> process;
    private DataFrameRowsJoiner dataFrameRowsJoiner;
    private ProgressTracker progressTracker = new ProgressTracker();
    private TrainedModelProvider trainedModelProvider;
    private DataFrameAnalyticsAuditor auditor;
    private DataFrameAnalyticsConfig analyticsConfig;

    @Before
    @SuppressWarnings("unchecked")
    public void setUpMocks() {
        process = mock(AnalyticsProcess.class);
        dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class);
        trainedModelProvider = mock(TrainedModelProvider.class);
        auditor = mock(DataFrameAnalyticsAuditor.class);
        analyticsConfig = new DataFrameAnalyticsConfig.Builder()
            .setId(JOB_ID)
            .setDescription(JOB_DESCRIPTION)
            .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null))
            .setDest(new DataFrameAnalyticsDest("my_dest", null))
            .setAnalysis(new Regression("foo"))
            .build();
    }

    public void testProcess_GivenNoResults() {

@@ -54,7 +88,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {

    public void testProcess_GivenEmptyResults() {
        givenDataFrameRows(2);
        givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50), new AnalyticsResult(null, 100)));
        givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50, null), new AnalyticsResult(null, 100, null)));
        AnalyticsResultProcessor resultProcessor = createResultProcessor();

        resultProcessor.process(process);

@@ -69,7 +103,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
        givenDataFrameRows(2);
        RowResults rowResults1 = mock(RowResults.class);
        RowResults rowResults2 = mock(RowResults.class);
        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50), new AnalyticsResult(rowResults2, 100)));
        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null), new AnalyticsResult(rowResults2, 100, null)));
        AnalyticsResultProcessor resultProcessor = createResultProcessor();

        resultProcessor.process(process);
@@ -82,6 +116,71 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
        assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
    }

    @SuppressWarnings("unchecked")
    public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
        givenDataFrameRows(0);

        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
            storeListener.onResponse(true);
            return null;
        }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));

        TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
        AnalyticsResultProcessor resultProcessor = createResultProcessor();

        resultProcessor.process(process);
        resultProcessor.awaitForCompletion();

        ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
        verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));

        TrainedModelConfig storedModel = storedModelCaptor.getValue();
        assertThat(storedModel.getModelId(), containsString(JOB_ID));
        assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
        assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));
        assertThat(storedModel.getTags(), contains(JOB_ID));
        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
        assertThat(storedModel.getDefinition(), equalTo(inferenceModel));
        Map<String, Object> metadata = storedModel.getMetadata();
        assertThat(metadata.size(), equalTo(1));
        assertThat(metadata, hasKey("analytics_config"));
        Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
            true);
        assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));

        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
        verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
        assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
        Mockito.verifyNoMoreInteractions(auditor);
    }

    @SuppressWarnings("unchecked")
    public void testProcess_GivenInferenceModelFailedToStore() {
        givenDataFrameRows(0);

        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
            storeListener.onFailure(new RuntimeException("some failure"));
            return null;
        }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));

        TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
        AnalyticsResultProcessor resultProcessor = createResultProcessor();

        resultProcessor.process(process);
        resultProcessor.awaitForCompletion();

        // This test verifies the processor knows how to handle a failure on storing the model and completes normally
        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
        verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
        assertThat(auditCaptor.getValue(), containsString("Error storing trained model with id [" + JOB_ID));
        assertThat(auditCaptor.getValue(), containsString("[some failure]"));
        Mockito.verifyNoMoreInteractions(auditor);
    }

    private void givenProcessResults(List<AnalyticsResult> results) {
        when(process.readAnalyticsResults()).thenReturn(results.iterator());
    }

@@ -93,6 +192,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
    }

    private AnalyticsResultProcessor createResultProcessor() {
        return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, progressTracker);
        return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider,
            auditor);
    }
}

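The two new tests stub the asynchronous store call with Mockito's doAnswer, completing the captured listener by hand so the processor's callback path runs synchronously in the test. The pattern in isolation, as a hedged sketch (the `mockedProvider` factory is hypothetical; the stubbed call and argument index are exactly those used above):

import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

class DoAnswerPatternExample {
    @SuppressWarnings("unchecked")
    static TrainedModelProvider mockedProvider(boolean succeed) {
        TrainedModelProvider provider = mock(TrainedModelProvider.class);
        doAnswer(invocation -> {
            // argument 1 is the ActionListener<Boolean> passed by the code under test
            ActionListener<Boolean> listener = (ActionListener<Boolean>) invocation.getArguments()[1];
            if (succeed) {
                listener.onResponse(true);
            } else {
                listener.onFailure(new RuntimeException("some failure"));
            }
            return null; // storeTrainedModel returns void
        }).when(provider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
        return provider;
    }
}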
@@ -5,24 +5,45 @@
 */
package org.elasticsearch.xpack.ml.dataframe.process.results;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResult> {

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(namedXContent);
    }

    @Override
    protected AnalyticsResult createTestInstance() {
        RowResults rowResults = null;
        Integer progressPercent = null;
        TrainedModelDefinition inferenceModel = null;
        if (randomBoolean()) {
            rowResults = RowResultsTests.createRandom();
        }
        if (randomBoolean()) {
            progressPercent = randomIntBetween(0, 100);
        }
        return new AnalyticsResult(rowResults, progressPercent);
        if (randomBoolean()) {
            inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
        }
        return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
    }

    @Override

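The registry override above is what makes round-trip testing of results that embed a TrainedModelDefinition possible: the ML inference provider contributes the named parsers for the model's polymorphic parts, and the SearchModule entries cover the rest. The same composition can be reused wherever such a result is parsed; a sketch lifting exactly the calls from the test (the wrapper class is hypothetical):

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;

class InferenceXContentRegistry {
    static NamedXContentRegistry build() {
        List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
        // named parsers for trained models and their preprocessors
        entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        // standard search named-xcontent entries
        entries.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(entries);
    }
}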
@@ -39,7 +39,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {

    public void testPutTrainedModelConfig() throws Exception {
        String modelId = "test-put-trained-model-config";
        TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
        TrainedModelConfig config = buildTrainedModelConfig(modelId);
        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

@@ -50,7 +50,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {

    public void testPutTrainedModelConfigThatAlreadyExists() throws Exception {
        String modelId = "test-put-trained-model-config-exists";
        TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
        TrainedModelConfig config = buildTrainedModelConfig(modelId);
        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

@@ -61,12 +61,12 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
        assertThat(exceptionHolder.get(), is(not(nullValue())));
        assertThat(exceptionHolder.get().getMessage(),
            equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId, 0)));
            equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
    }

    public void testGetTrainedModelConfig() throws Exception {
        String modelId = "test-get-trained-model-config";
        TrainedModelConfig config = buildTrainedModelConfig(modelId, 0);
        TrainedModelConfig config = buildTrainedModelConfig(modelId);
        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

@@ -75,7 +75,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
        assertThat(exceptionHolder.get(), is(nullValue()));

        AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
        assertThat(getConfigHolder.get(), is(not(nullValue())));
        assertThat(getConfigHolder.get(), equalTo(config));
    }

@@ -84,21 +84,20 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
        String modelId = "test-get-missing-trained-model-config";
        AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, 0, listener), getConfigHolder, exceptionHolder);
        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
        assertThat(exceptionHolder.get(), is(not(nullValue())));
        assertThat(exceptionHolder.get().getMessage(),
            equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, 0)));
            equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
    }

    private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
    private static TrainedModelConfig buildTrainedModelConfig(String modelId) {
        return TrainedModelConfig.builder()
            .setCreatedBy("ml_test")
            .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
            .setDescription("trained model config for test")
            .setModelId(modelId)
            .setModelType("binary_decision_tree")
            .setModelVersion(modelVersion)
            .build(Version.CURRENT);
            .setVersion(Version.CURRENT)
            .build();
    }

    @Override

@@ -5,6 +5,7 @@
 */
package org.elasticsearch.xpack.ml.job.process.autodetect;

import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;

@@ -61,7 +62,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
        try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
                mock(OutputStream.class), outputStream, mock(OutputStream.class),
                NUMBER_FIELDS, null,
                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
            process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));

            ZonedDateTime startTime = process.getProcessStartTime();

@@ -84,7 +85,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
        ByteArrayOutputStream bos = new ByteArrayOutputStream(1024);
        try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
                bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
            process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));

            process.writeRecord(record);

@@ -119,7 +120,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
        ByteArrayOutputStream bos = new ByteArrayOutputStream(AutodetectControlMsgWriter.FLUSH_SPACES_LENGTH + 1024);
        try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
                bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
            process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));

            FlushJobParams params = FlushJobParams.builder().build();

@@ -153,7 +154,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {

        try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
                processInStream, processOutStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
                new ProcessResultsParser<AutodetectResult>(AutodetectResult.PARSER), mock(Consumer.class))) {
                new ProcessResultsParser<AutodetectResult>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {

            process.consumeAndCloseOutputStream();
            assertThat(processOutStream.available(), equalTo(0));

@@ -169,7 +170,7 @@ public class NativeAutodetectProcessTests extends ESTestCase {
        ByteArrayOutputStream bos = new ByteArrayOutputStream(1024);
        try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream,
                bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(),
                new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Consumer.class))) {
                new ProcessResultsParser<>(AutodetectResult.PARSER, NamedXContentRegistry.EMPTY), mock(Consumer.class))) {
            process.start(executorService, mock(IndexingStateProcessor.class), mock(InputStream.class));

            writeFunction.accept(process);

@@ -9,6 +9,7 @@ import com.google.common.base.Charsets;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.test.ESTestCase;

@@ -28,7 +29,7 @@ public class ProcessResultsParserTests extends ESTestCase {
    public void testParse_GivenEmptyArray() throws IOException {
        String json = "[]";
        try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
            assertFalse(parser.parseResults(inputStream).hasNext());
        }
    }

@@ -36,7 +37,7 @@ public class ProcessResultsParserTests extends ESTestCase {
    public void testParse_GivenUnknownObject() throws IOException {
        String json = "[{\"unknown\":{\"id\": 18}}]";
        try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
            XContentParseException e = expectThrows(XContentParseException.class,
                () -> parser.parseResults(inputStream).forEachRemaining(a -> {
                }));

@@ -47,7 +48,7 @@ public class ProcessResultsParserTests extends ESTestCase {
    public void testParse_GivenArrayContainsAnotherArray() throws IOException {
        String json = "[[]]";
        try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) {
            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
            ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class,
                () -> parser.parseResults(inputStream).forEachRemaining(a -> {
                }));

@@ -60,7 +61,7 @@ public class ProcessResultsParserTests extends ESTestCase {
            + " {\"field_1\": \"c\", \"field_2\": 3.0}]";
        try (InputStream inputStream = new ByteArrayInputStream(input.getBytes(Charsets.UTF_8))) {

            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER);
            ProcessResultsParser<TestResult> parser = new ProcessResultsParser<>(TestResult.PARSER, NamedXContentRegistry.EMPTY);
            Iterator<TestResult> testResultIterator = parser.parseResults(inputStream);

            List<TestResult> parsedResults = new ArrayList<>();