* [ML][Inference] separating definition and config object storage (#48651) This separates out the `definition` object from being stored within the configuration object in the index. This allows us to gather the config object without decompressing a potentially large definition. Additionally, `input` is moved to the TrainedModelConfig object and out of the definition. This is so the trained input fields are accessible outside the potentially large model definition.
This commit is contained in:
parent
0476f014bc
commit
c9ead80c31
|
@ -46,6 +46,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
public static final ParseField DEFINITION = new ParseField("definition");
|
public static final ParseField DEFINITION = new ParseField("definition");
|
||||||
public static final ParseField TAGS = new ParseField("tags");
|
public static final ParseField TAGS = new ParseField("tags");
|
||||||
public static final ParseField METADATA = new ParseField("metadata");
|
public static final ParseField METADATA = new ParseField("metadata");
|
||||||
|
public static final ParseField INPUT = new ParseField("input");
|
||||||
|
|
||||||
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
|
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
|
||||||
true,
|
true,
|
||||||
|
@ -64,6 +65,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
DEFINITION);
|
DEFINITION);
|
||||||
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
|
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
|
||||||
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
||||||
|
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
|
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
|
||||||
|
@ -78,6 +80,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
private final TrainedModelDefinition definition;
|
private final TrainedModelDefinition definition;
|
||||||
private final List<String> tags;
|
private final List<String> tags;
|
||||||
private final Map<String, Object> metadata;
|
private final Map<String, Object> metadata;
|
||||||
|
private final TrainedModelInput input;
|
||||||
|
|
||||||
TrainedModelConfig(String modelId,
|
TrainedModelConfig(String modelId,
|
||||||
String createdBy,
|
String createdBy,
|
||||||
|
@ -86,7 +89,8 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
Instant createTime,
|
Instant createTime,
|
||||||
TrainedModelDefinition definition,
|
TrainedModelDefinition definition,
|
||||||
List<String> tags,
|
List<String> tags,
|
||||||
Map<String, Object> metadata) {
|
Map<String, Object> metadata,
|
||||||
|
TrainedModelInput input) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
this.createdBy = createdBy;
|
this.createdBy = createdBy;
|
||||||
this.version = version;
|
this.version = version;
|
||||||
|
@ -95,6 +99,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
this.description = description;
|
this.description = description;
|
||||||
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
|
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
|
||||||
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
|
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
|
||||||
|
this.input = input;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getModelId() {
|
public String getModelId() {
|
||||||
|
@ -129,6 +134,10 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
return definition;
|
return definition;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public TrainedModelInput getInput() {
|
||||||
|
return input;
|
||||||
|
}
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new Builder();
|
return new Builder();
|
||||||
}
|
}
|
||||||
|
@ -160,6 +169,9 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
if (metadata != null) {
|
if (metadata != null) {
|
||||||
builder.field(METADATA.getPreferredName(), metadata);
|
builder.field(METADATA.getPreferredName(), metadata);
|
||||||
}
|
}
|
||||||
|
if (input != null) {
|
||||||
|
builder.field(INPUT.getPreferredName(), input);
|
||||||
|
}
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
@ -181,6 +193,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
Objects.equals(createTime, that.createTime) &&
|
Objects.equals(createTime, that.createTime) &&
|
||||||
Objects.equals(definition, that.definition) &&
|
Objects.equals(definition, that.definition) &&
|
||||||
Objects.equals(tags, that.tags) &&
|
Objects.equals(tags, that.tags) &&
|
||||||
|
Objects.equals(input, that.input) &&
|
||||||
Objects.equals(metadata, that.metadata);
|
Objects.equals(metadata, that.metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,7 +206,8 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
definition,
|
definition,
|
||||||
description,
|
description,
|
||||||
tags,
|
tags,
|
||||||
metadata);
|
metadata,
|
||||||
|
input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -207,6 +221,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
private Map<String, Object> metadata;
|
private Map<String, Object> metadata;
|
||||||
private List<String> tags;
|
private List<String> tags;
|
||||||
private TrainedModelDefinition definition;
|
private TrainedModelDefinition definition;
|
||||||
|
private TrainedModelInput input;
|
||||||
|
|
||||||
public Builder setModelId(String modelId) {
|
public Builder setModelId(String modelId) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
|
@ -257,6 +272,11 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setInput(TrainedModelInput input) {
|
||||||
|
this.input = input;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public TrainedModelConfig build() {
|
public TrainedModelConfig build() {
|
||||||
return new TrainedModelConfig(
|
return new TrainedModelConfig(
|
||||||
modelId,
|
modelId,
|
||||||
|
@ -266,7 +286,9 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
createTime,
|
createTime,
|
||||||
definition,
|
definition,
|
||||||
tags,
|
tags,
|
||||||
metadata);
|
metadata,
|
||||||
|
input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
|
||||||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|
||||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
@ -39,7 +38,6 @@ public class TrainedModelDefinition implements ToXContentObject {
|
||||||
|
|
||||||
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
||||||
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
||||||
public static final ParseField INPUT = new ParseField("input");
|
|
||||||
|
|
||||||
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
|
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
|
||||||
true,
|
true,
|
||||||
|
@ -53,7 +51,6 @@ public class TrainedModelDefinition implements ToXContentObject {
|
||||||
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
|
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
|
||||||
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
|
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
|
||||||
PREPROCESSORS);
|
PREPROCESSORS);
|
||||||
PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
|
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
|
||||||
|
@ -62,12 +59,10 @@ public class TrainedModelDefinition implements ToXContentObject {
|
||||||
|
|
||||||
private final TrainedModel trainedModel;
|
private final TrainedModel trainedModel;
|
||||||
private final List<PreProcessor> preProcessors;
|
private final List<PreProcessor> preProcessors;
|
||||||
private final Input input;
|
|
||||||
|
|
||||||
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
|
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
|
||||||
this.trainedModel = trainedModel;
|
this.trainedModel = trainedModel;
|
||||||
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
|
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
|
||||||
this.input = input;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -83,9 +78,6 @@ public class TrainedModelDefinition implements ToXContentObject {
|
||||||
true,
|
true,
|
||||||
PREPROCESSORS.getPreferredName(),
|
PREPROCESSORS.getPreferredName(),
|
||||||
preProcessors);
|
preProcessors);
|
||||||
if (input != null) {
|
|
||||||
builder.field(INPUT.getPreferredName(), input);
|
|
||||||
}
|
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
@ -98,10 +90,6 @@ public class TrainedModelDefinition implements ToXContentObject {
|
||||||
return preProcessors;
|
return preProcessors;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Input getInput() {
|
|
||||||
return input;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return Strings.toString(this);
|
return Strings.toString(this);
|
||||||
|
@ -113,20 +101,18 @@ public class TrainedModelDefinition implements ToXContentObject {
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
||||||
return Objects.equals(trainedModel, that.trainedModel) &&
|
return Objects.equals(trainedModel, that.trainedModel) &&
|
||||||
Objects.equals(preProcessors, that.preProcessors) &&
|
Objects.equals(preProcessors, that.preProcessors);
|
||||||
Objects.equals(input, that.input);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(trainedModel, preProcessors, input);
|
return Objects.hash(trainedModel, preProcessors);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
|
|
||||||
private List<PreProcessor> preProcessors;
|
private List<PreProcessor> preProcessors;
|
||||||
private TrainedModel trainedModel;
|
private TrainedModel trainedModel;
|
||||||
private Input input;
|
|
||||||
|
|
||||||
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
|
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
|
||||||
this.preProcessors = preProcessors;
|
this.preProcessors = preProcessors;
|
||||||
|
@ -138,71 +124,14 @@ public class TrainedModelDefinition implements ToXContentObject {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder setInput(Input input) {
|
|
||||||
this.input = input;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
|
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
|
||||||
assert trainedModel.size() == 1;
|
assert trainedModel.size() == 1;
|
||||||
return setTrainedModel(trainedModel.get(0));
|
return setTrainedModel(trainedModel.get(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainedModelDefinition build() {
|
public TrainedModelDefinition build() {
|
||||||
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
|
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Input implements ToXContentObject {
|
|
||||||
|
|
||||||
public static final String NAME = "trained_mode_definition_input";
|
|
||||||
public static final ParseField FIELD_NAMES = new ParseField("field_names");
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
|
||||||
true,
|
|
||||||
a -> new Input((List<String>)a[0]));
|
|
||||||
static {
|
|
||||||
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Input fromXContent(XContentParser parser) throws IOException {
|
|
||||||
return PARSER.parse(parser, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
private final List<String> fieldNames;
|
|
||||||
|
|
||||||
public Input(List<String> fieldNames) {
|
|
||||||
this.fieldNames = fieldNames;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getFieldNames() {
|
|
||||||
return fieldNames;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
||||||
builder.startObject();
|
|
||||||
if (fieldNames != null) {
|
|
||||||
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
|
|
||||||
}
|
|
||||||
builder.endObject();
|
|
||||||
return builder;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o) return true;
|
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
|
||||||
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
|
|
||||||
return Objects.equals(fieldNames, that.fieldNames);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return Objects.hash(fieldNames);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class TrainedModelInput implements ToXContentObject {
|
||||||
|
|
||||||
|
public static final String NAME = "trained_model_config_input";
|
||||||
|
public static final ParseField FIELD_NAMES = new ParseField("field_names");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public static final ConstructingObjectParser<TrainedModelInput, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||||
|
true,
|
||||||
|
a -> new TrainedModelInput((List<String>) a[0]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final List<String> fieldNames;
|
||||||
|
|
||||||
|
public TrainedModelInput(List<String> fieldNames) {
|
||||||
|
this.fieldNames = fieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
|
||||||
|
return PARSER.parse(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getFieldNames() {
|
||||||
|
return fieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
if (fieldNames != null) {
|
||||||
|
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
TrainedModelInput that = (TrainedModelInput) o;
|
||||||
|
return Objects.equals(fieldNames, that.fieldNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(fieldNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -63,7 +63,8 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
|
||||||
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
||||||
randomBoolean() ? null :
|
randomBoolean() ? null :
|
||||||
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
|
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
|
||||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||||
|
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -64,10 +64,7 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
|
||||||
TargetMeanEncodingTests.createRandom()))
|
TargetMeanEncodingTests.createRandom()))
|
||||||
.limit(numberOfProcessors)
|
.limit(numberOfProcessors)
|
||||||
.collect(Collectors.toList()))
|
.collect(Collectors.toList()))
|
||||||
.setTrainedModel(randomFrom(TreeTests.createRandom()))
|
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
||||||
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
|
|
||||||
.limit(randomLongBetween(1, 10))
|
|
||||||
.collect(Collectors.toList())));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
|
public class TrainedModelInputTests extends AbstractXContentTestCase<TrainedModelInput> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelInput doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return TrainedModelInput.fromXContent(parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||||
|
return field -> !field.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelInput createRandomInput() {
|
||||||
|
return new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10))
|
||||||
|
.limit(randomLongBetween(1, 10))
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelInput createTestInstance() {
|
||||||
|
return createRandomInput();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -42,6 +42,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
public static final ParseField DEFINITION = new ParseField("definition");
|
public static final ParseField DEFINITION = new ParseField("definition");
|
||||||
public static final ParseField TAGS = new ParseField("tags");
|
public static final ParseField TAGS = new ParseField("tags");
|
||||||
public static final ParseField METADATA = new ParseField("metadata");
|
public static final ParseField METADATA = new ParseField("metadata");
|
||||||
|
public static final ParseField INPUT = new ParseField("input");
|
||||||
|
|
||||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||||
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
|
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
|
||||||
|
@ -61,10 +62,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
ObjectParser.ValueType.VALUE);
|
ObjectParser.ValueType.VALUE);
|
||||||
parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
|
parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
|
||||||
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
||||||
parser.declareObject(TrainedModelConfig.Builder::setDefinition,
|
|
||||||
(p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
|
|
||||||
DEFINITION);
|
|
||||||
parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE);
|
parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE);
|
||||||
|
parser.declareObject(TrainedModelConfig.Builder::setInput,
|
||||||
|
(p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields),
|
||||||
|
INPUT);
|
||||||
return parser;
|
return parser;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,10 +80,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
private final Instant createTime;
|
private final Instant createTime;
|
||||||
private final List<String> tags;
|
private final List<String> tags;
|
||||||
private final Map<String, Object> metadata;
|
private final Map<String, Object> metadata;
|
||||||
|
private final TrainedModelInput input;
|
||||||
|
|
||||||
// TODO how to reference and store large models that will not be executed in Java???
|
|
||||||
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
|
|
||||||
// TODO Should this be lazily parsed when loading via the index???
|
|
||||||
private final TrainedModelDefinition definition;
|
private final TrainedModelDefinition definition;
|
||||||
|
|
||||||
TrainedModelConfig(String modelId,
|
TrainedModelConfig(String modelId,
|
||||||
|
@ -92,7 +91,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
Instant createTime,
|
Instant createTime,
|
||||||
TrainedModelDefinition definition,
|
TrainedModelDefinition definition,
|
||||||
List<String> tags,
|
List<String> tags,
|
||||||
Map<String, Object> metadata) {
|
Map<String, Object> metadata,
|
||||||
|
TrainedModelInput input) {
|
||||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||||
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
|
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
|
||||||
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
|
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
|
||||||
|
@ -101,6 +101,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
this.description = description;
|
this.description = description;
|
||||||
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
|
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
|
||||||
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
|
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
|
||||||
|
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainedModelConfig(StreamInput in) throws IOException {
|
public TrainedModelConfig(StreamInput in) throws IOException {
|
||||||
|
@ -112,6 +113,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
definition = in.readOptionalWriteable(TrainedModelDefinition::new);
|
definition = in.readOptionalWriteable(TrainedModelDefinition::new);
|
||||||
tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
|
tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
|
||||||
metadata = in.readMap();
|
metadata = in.readMap();
|
||||||
|
input = new TrainedModelInput(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getModelId() {
|
public String getModelId() {
|
||||||
|
@ -147,6 +149,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
return definition;
|
return definition;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public TrainedModelInput getInput() {
|
||||||
|
return input;
|
||||||
|
}
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new Builder();
|
return new Builder();
|
||||||
}
|
}
|
||||||
|
@ -161,6 +167,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
out.writeOptionalWriteable(definition);
|
out.writeOptionalWriteable(definition);
|
||||||
out.writeCollection(tags, StreamOutput::writeString);
|
out.writeCollection(tags, StreamOutput::writeString);
|
||||||
out.writeMap(metadata);
|
out.writeMap(metadata);
|
||||||
|
input.writeTo(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -173,7 +180,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
builder.field(DESCRIPTION.getPreferredName(), description);
|
builder.field(DESCRIPTION.getPreferredName(), description);
|
||||||
}
|
}
|
||||||
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
||||||
if (definition != null) {
|
|
||||||
|
// We don't store the definition in the same document as the configuration
|
||||||
|
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
|
||||||
builder.field(DEFINITION.getPreferredName(), definition);
|
builder.field(DEFINITION.getPreferredName(), definition);
|
||||||
}
|
}
|
||||||
builder.field(TAGS.getPreferredName(), tags);
|
builder.field(TAGS.getPreferredName(), tags);
|
||||||
|
@ -183,6 +192,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||||
}
|
}
|
||||||
|
builder.field(INPUT.getPreferredName(), input);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
@ -204,6 +214,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
Objects.equals(createTime, that.createTime) &&
|
Objects.equals(createTime, that.createTime) &&
|
||||||
Objects.equals(definition, that.definition) &&
|
Objects.equals(definition, that.definition) &&
|
||||||
Objects.equals(tags, that.tags) &&
|
Objects.equals(tags, that.tags) &&
|
||||||
|
Objects.equals(input, that.input) &&
|
||||||
Objects.equals(metadata, that.metadata);
|
Objects.equals(metadata, that.metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,7 +227,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
definition,
|
definition,
|
||||||
description,
|
description,
|
||||||
tags,
|
tags,
|
||||||
metadata);
|
metadata,
|
||||||
|
input);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
|
@ -228,6 +240,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
private Instant createTime;
|
private Instant createTime;
|
||||||
private List<String> tags = Collections.emptyList();
|
private List<String> tags = Collections.emptyList();
|
||||||
private Map<String, Object> metadata;
|
private Map<String, Object> metadata;
|
||||||
|
private TrainedModelInput input;
|
||||||
private TrainedModelDefinition definition;
|
private TrainedModelDefinition definition;
|
||||||
|
|
||||||
public Builder setModelId(String modelId) {
|
public Builder setModelId(String modelId) {
|
||||||
|
@ -279,9 +292,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setInput(TrainedModelInput input) {
|
||||||
|
this.input = input;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO move to REST level instead of here in the builder
|
// TODO move to REST level instead of here in the builder
|
||||||
public void validate() {
|
public void validate() {
|
||||||
// We require a definition to be available until we support other means of supplying the definition
|
// We require a definition to be available here even though it will be stored in a different doc
|
||||||
ExceptionsHelper.requireNonNull(definition, DEFINITION);
|
ExceptionsHelper.requireNonNull(definition, DEFINITION);
|
||||||
ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||||
|
|
||||||
|
@ -320,7 +338,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
createTime == null ? Instant.now() : createTime,
|
createTime == null ? Instant.now() : createTime,
|
||||||
definition,
|
definition,
|
||||||
tags,
|
tags,
|
||||||
metadata);
|
metadata,
|
||||||
|
input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,16 +5,17 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.inference;
|
package org.elasticsearch.xpack.core.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|
||||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||||
|
@ -23,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrai
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
@ -31,11 +33,10 @@ import java.util.Objects;
|
||||||
|
|
||||||
public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
public static final String NAME = "trained_mode_definition";
|
public static final String NAME = "trained_model_definition";
|
||||||
|
|
||||||
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
||||||
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
||||||
public static final ParseField INPUT = new ParseField("input");
|
|
||||||
|
|
||||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||||
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
|
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
|
||||||
|
@ -44,7 +45,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) {
|
private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) {
|
||||||
ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
|
ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
|
||||||
ignoreUnknownFields,
|
ignoreUnknownFields,
|
||||||
TrainedModelDefinition.Builder::new);
|
TrainedModelDefinition.Builder::builderForParser);
|
||||||
parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
|
parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
|
||||||
(p, c, n) -> ignoreUnknownFields ?
|
(p, c, n) -> ignoreUnknownFields ?
|
||||||
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
|
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
|
||||||
|
@ -57,7 +58,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
|
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
|
||||||
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
|
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
|
||||||
PREPROCESSORS);
|
PREPROCESSORS);
|
||||||
parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT);
|
parser.declareString(TrainedModelDefinition.Builder::setModelId, TrainedModelConfig.MODEL_ID);
|
||||||
return parser;
|
return parser;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,27 +66,31 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String docId(String modelId) {
|
||||||
|
return NAME + "-" + modelId;
|
||||||
|
}
|
||||||
|
|
||||||
private final TrainedModel trainedModel;
|
private final TrainedModel trainedModel;
|
||||||
private final List<PreProcessor> preProcessors;
|
private final List<PreProcessor> preProcessors;
|
||||||
private final Input input;
|
private final String modelId;
|
||||||
|
|
||||||
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
|
private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, @Nullable String modelId) {
|
||||||
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
|
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
|
||||||
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
|
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
|
||||||
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
|
this.modelId = modelId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainedModelDefinition(StreamInput in) throws IOException {
|
public TrainedModelDefinition(StreamInput in) throws IOException {
|
||||||
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
|
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
|
||||||
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
|
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
|
||||||
this.input = new Input(in);
|
this.modelId = in.readOptionalString();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeNamedWriteable(trainedModel);
|
out.writeNamedWriteable(trainedModel);
|
||||||
out.writeNamedWriteableList(preProcessors);
|
out.writeNamedWriteableList(preProcessors);
|
||||||
input.writeTo(out);
|
out.writeOptionalString(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -101,7 +106,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
true,
|
true,
|
||||||
PREPROCESSORS.getPreferredName(),
|
PREPROCESSORS.getPreferredName(),
|
||||||
preProcessors);
|
preProcessors);
|
||||||
builder.field(INPUT.getPreferredName(), input);
|
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
|
||||||
|
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||||
|
assert modelId != null;
|
||||||
|
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
|
||||||
|
}
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
@ -114,10 +123,6 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
return preProcessors;
|
return preProcessors;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Input getInput() {
|
|
||||||
return input;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return Strings.toString(this);
|
return Strings.toString(this);
|
||||||
|
@ -129,21 +134,21 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
||||||
return Objects.equals(trainedModel, that.trainedModel) &&
|
return Objects.equals(trainedModel, that.trainedModel) &&
|
||||||
Objects.equals(input, that.input) &&
|
Objects.equals(preProcessors, that.preProcessors) &&
|
||||||
Objects.equals(preProcessors, that.preProcessors);
|
Objects.equals(modelId, that.modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(trainedModel, input, preProcessors);
|
return Objects.hash(trainedModel, preProcessors, modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
|
|
||||||
private List<PreProcessor> preProcessors;
|
private List<PreProcessor> preProcessors;
|
||||||
private TrainedModel trainedModel;
|
private TrainedModel trainedModel;
|
||||||
|
private String modelId;
|
||||||
private boolean processorsInOrder;
|
private boolean processorsInOrder;
|
||||||
private Input input;
|
|
||||||
|
|
||||||
private static Builder builderForParser() {
|
private static Builder builderForParser() {
|
||||||
return new Builder(false);
|
return new Builder(false);
|
||||||
|
@ -167,8 +172,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder setInput(Input input) {
|
public Builder setModelId(String modelId) {
|
||||||
this.input = input;
|
this.modelId = modelId;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,71 +193,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
|
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
|
||||||
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
|
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
|
||||||
}
|
}
|
||||||
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
|
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.modelId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Input implements ToXContentObject, Writeable {
|
|
||||||
|
|
||||||
public static final String NAME = "trained_mode_definition_input";
|
|
||||||
public static final ParseField FIELD_NAMES = new ParseField("field_names");
|
|
||||||
|
|
||||||
public static final ConstructingObjectParser<Input, Void> LENIENT_PARSER = createParser(true);
|
|
||||||
public static final ConstructingObjectParser<Input, Void> STRICT_PARSER = createParser(false);
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
private static ConstructingObjectParser<Input, Void> createParser(boolean ignoreUnknownFields) {
|
|
||||||
ConstructingObjectParser<Input, Void> parser = new ConstructingObjectParser<>(NAME,
|
|
||||||
ignoreUnknownFields,
|
|
||||||
a -> new Input((List<String>)a[0]));
|
|
||||||
parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
|
|
||||||
return parser;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException {
|
|
||||||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
private final List<String> fieldNames;
|
|
||||||
|
|
||||||
public Input(List<String> fieldNames) {
|
|
||||||
this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
|
|
||||||
}
|
|
||||||
|
|
||||||
public Input(StreamInput in) throws IOException {
|
|
||||||
this.fieldNames = Collections.unmodifiableList(in.readStringList());
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getFieldNames() {
|
|
||||||
return fieldNames;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
|
||||||
out.writeStringCollection(fieldNames);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
||||||
builder.startObject();
|
|
||||||
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
|
|
||||||
builder.endObject();
|
|
||||||
return builder;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o) return true;
|
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
|
||||||
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
|
|
||||||
return Objects.equals(fieldNames, that.fieldNames);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return Objects.hash(fieldNames);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
|
||||||
|
public class TrainedModelInput implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
|
public static final String NAME = "trained_model_config_input";
|
||||||
|
public static final ParseField FIELD_NAMES = new ParseField("field_names");
|
||||||
|
|
||||||
|
public static final ConstructingObjectParser<TrainedModelInput, Void> LENIENT_PARSER = createParser(true);
|
||||||
|
public static final ConstructingObjectParser<TrainedModelInput, Void> STRICT_PARSER = createParser(false);
|
||||||
|
private final List<String> fieldNames;
|
||||||
|
|
||||||
|
public TrainedModelInput(List<String> fieldNames) {
|
||||||
|
this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainedModelInput(StreamInput in) throws IOException {
|
||||||
|
this.fieldNames = Collections.unmodifiableList(in.readStringList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static ConstructingObjectParser<TrainedModelInput, Void> createParser(boolean ignoreUnknownFields) {
|
||||||
|
ConstructingObjectParser<TrainedModelInput, Void> parser = new ConstructingObjectParser<>(NAME,
|
||||||
|
ignoreUnknownFields,
|
||||||
|
a -> new TrainedModelInput((List<String>) a[0]));
|
||||||
|
parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
|
||||||
|
return parser;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelInput fromXContent(XContentParser parser, boolean lenient) throws IOException {
|
||||||
|
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getFieldNames() {
|
||||||
|
return fieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeStringCollection(fieldNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
TrainedModelInput that = (TrainedModelInput) o;
|
||||||
|
return Objects.equals(fieldNames, that.fieldNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(fieldNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -82,9 +82,8 @@ public final class Messages {
|
||||||
|
|
||||||
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
|
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
|
||||||
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
|
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
|
||||||
public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL =
|
|
||||||
"Failed to serialize the trained model [{0}] with version [{1}] for storage";
|
|
||||||
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
|
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
|
||||||
|
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
|
||||||
|
|
||||||
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
|
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
|
||||||
public static final String JOB_AUDIT_CREATED = "Job created";
|
public static final String JOB_AUDIT_CREATED = "Job created";
|
||||||
|
|
|
@ -7,15 +7,20 @@ package org.elasticsearch.xpack.core.ml.inference;
|
||||||
|
|
||||||
import org.elasticsearch.ElasticsearchException;
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.Version;
|
import org.elasticsearch.Version;
|
||||||
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentType;
|
||||||
import org.elasticsearch.search.SearchModule;
|
import org.elasticsearch.search.SearchModule;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -28,7 +33,9 @@ import java.util.function.Predicate;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.not;
|
||||||
|
|
||||||
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
|
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
|
||||||
|
|
||||||
|
@ -63,9 +70,10 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
Version.CURRENT,
|
Version.CURRENT,
|
||||||
randomBoolean() ? null : randomAlphaOfLength(100),
|
randomBoolean() ? null : randomAlphaOfLength(100),
|
||||||
Instant.ofEpochMilli(randomNonNegativeLong()),
|
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||||
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
null, // is not parsed so should not be provided
|
||||||
tags,
|
tags,
|
||||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||||
|
TrainedModelInputTests.createRandomInput());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -88,6 +96,28 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
return new NamedWriteableRegistry(entries);
|
return new NamedWriteableRegistry(entries);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testToXContentWithParams() throws IOException {
|
||||||
|
TrainedModelConfig config = new TrainedModelConfig(
|
||||||
|
randomAlphaOfLength(10),
|
||||||
|
randomAlphaOfLength(10),
|
||||||
|
Version.CURRENT,
|
||||||
|
randomBoolean() ? null : randomAlphaOfLength(100),
|
||||||
|
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||||
|
TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(),
|
||||||
|
Collections.emptyList(),
|
||||||
|
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||||
|
TrainedModelInputTests.createRandomInput());
|
||||||
|
|
||||||
|
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
|
||||||
|
assertThat(reference.utf8ToString(), containsString("definition"));
|
||||||
|
|
||||||
|
reference = XContentHelper.toXContent(config,
|
||||||
|
XContentType.JSON,
|
||||||
|
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
|
||||||
|
false);
|
||||||
|
assertThat(reference.utf8ToString(), not(containsString("definition")));
|
||||||
|
}
|
||||||
|
|
||||||
public void testValidateWithNullDefinition() {
|
public void testValidateWithNullDefinition() {
|
||||||
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
|
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
|
||||||
assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
|
assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
|
||||||
|
@ -97,7 +127,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
String modelId = "InvalidID-";
|
String modelId = "InvalidID-";
|
||||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
|
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
|
||||||
}
|
}
|
||||||
|
@ -106,7 +136,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
|
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
|
||||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
|
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
|
||||||
}
|
}
|
||||||
|
@ -115,21 +145,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
String modelId = "simplemodel";
|
String modelId = "simplemodel";
|
||||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
|
||||||
.setCreateTime(Instant.now())
|
.setCreateTime(Instant.now())
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
|
assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
|
||||||
|
|
||||||
ex = expectThrows(ElasticsearchException.class,
|
ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
|
||||||
.setVersion(Version.CURRENT)
|
.setVersion(Version.CURRENT)
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
|
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
|
||||||
|
|
||||||
ex = expectThrows(ElasticsearchException.class,
|
ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
|
||||||
.setCreatedBy("ml_user")
|
.setCreatedBy("ml_user")
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
|
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
|
||||||
|
|
|
@ -58,9 +58,10 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
||||||
return field -> !field.isEmpty();
|
return field -> !field.isEmpty();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static TrainedModelDefinition.Builder createRandomBuilder() {
|
public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) {
|
||||||
int numberOfProcessors = randomIntBetween(1, 10);
|
int numberOfProcessors = randomIntBetween(1, 10);
|
||||||
return new TrainedModelDefinition.Builder()
|
return new TrainedModelDefinition.Builder()
|
||||||
|
.setModelId(modelId)
|
||||||
.setPreProcessors(
|
.setPreProcessors(
|
||||||
randomBoolean() ? null :
|
randomBoolean() ? null :
|
||||||
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||||
|
@ -68,22 +69,11 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
||||||
TargetMeanEncodingTests.createRandom()))
|
TargetMeanEncodingTests.createRandom()))
|
||||||
.limit(numberOfProcessors)
|
.limit(numberOfProcessors)
|
||||||
.collect(Collectors.toList()))
|
.collect(Collectors.toList()))
|
||||||
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
|
|
||||||
.limit(randomLongBetween(1, 10))
|
|
||||||
.collect(Collectors.toList())))
|
|
||||||
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final String ENSEMBLE_MODEL = "" +
|
private static final String ENSEMBLE_MODEL = "" +
|
||||||
"{\n" +
|
"{\n" +
|
||||||
" \"input\": {\n" +
|
|
||||||
" \"field_names\": [\n" +
|
|
||||||
" \"col1\",\n" +
|
|
||||||
" \"col2\",\n" +
|
|
||||||
" \"col3\",\n" +
|
|
||||||
" \"col4\"\n" +
|
|
||||||
" ]\n" +
|
|
||||||
" },\n" +
|
|
||||||
" \"preprocessors\": [\n" +
|
" \"preprocessors\": [\n" +
|
||||||
" {\n" +
|
" {\n" +
|
||||||
" \"one_hot_encoding\": {\n" +
|
" \"one_hot_encoding\": {\n" +
|
||||||
|
@ -203,14 +193,6 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
||||||
"}";
|
"}";
|
||||||
private static final String TREE_MODEL = "" +
|
private static final String TREE_MODEL = "" +
|
||||||
"{\n" +
|
"{\n" +
|
||||||
" \"input\": {\n" +
|
|
||||||
" \"field_names\": [\n" +
|
|
||||||
" \"col1\",\n" +
|
|
||||||
" \"col2\",\n" +
|
|
||||||
" \"col3\",\n" +
|
|
||||||
" \"col4\"\n" +
|
|
||||||
" ]\n" +
|
|
||||||
" },\n" +
|
|
||||||
" \"preprocessors\": [\n" +
|
" \"preprocessors\": [\n" +
|
||||||
" {\n" +
|
" {\n" +
|
||||||
" \"one_hot_encoding\": {\n" +
|
" \"one_hot_encoding\": {\n" +
|
||||||
|
@ -293,7 +275,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected TrainedModelDefinition createTestInstance() {
|
protected TrainedModelDefinition createTestInstance() {
|
||||||
return createRandomBuilder().build();
|
return createRandomBuilder(null).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
|
public class TrainedModelInputTests extends AbstractSerializingTestCase<TrainedModelInput> {
|
||||||
|
|
||||||
|
private boolean lenient;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void chooseStrictOrLenient() {
|
||||||
|
lenient = randomBoolean();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelInput doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return TrainedModelInput.fromXContent(parser, lenient);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return lenient;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||||
|
return field -> !field.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelInput createRandomInput() {
|
||||||
|
return new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10))
|
||||||
|
.limit(randomInt(10))
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelInput createTestInstance() {
|
||||||
|
return createRandomInput();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<TrainedModelInput> instanceReader() {
|
||||||
|
return TrainedModelInput::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -366,7 +366,7 @@ public class AnalyticsProcessManager {
|
||||||
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
|
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
|
||||||
dataExtractorFactory.newExtractor(true));
|
dataExtractorFactory.newExtractor(true));
|
||||||
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
|
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
|
||||||
trainedModelProvider, auditor);
|
trainedModelProvider, auditor, dataExtractor.getFieldNames());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
|
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
||||||
|
@ -26,6 +27,7 @@ import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
@ -41,18 +43,21 @@ public class AnalyticsResultProcessor {
|
||||||
private final ProgressTracker progressTracker;
|
private final ProgressTracker progressTracker;
|
||||||
private final TrainedModelProvider trainedModelProvider;
|
private final TrainedModelProvider trainedModelProvider;
|
||||||
private final DataFrameAnalyticsAuditor auditor;
|
private final DataFrameAnalyticsAuditor auditor;
|
||||||
|
private final List<String> fieldNames;
|
||||||
private final CountDownLatch completionLatch = new CountDownLatch(1);
|
private final CountDownLatch completionLatch = new CountDownLatch(1);
|
||||||
private volatile String failure;
|
private volatile String failure;
|
||||||
|
|
||||||
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
|
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
|
||||||
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
|
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
|
||||||
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor) {
|
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor,
|
||||||
|
List<String> fieldNames) {
|
||||||
this.analytics = Objects.requireNonNull(analytics);
|
this.analytics = Objects.requireNonNull(analytics);
|
||||||
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
|
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
|
||||||
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
|
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
|
||||||
this.progressTracker = Objects.requireNonNull(progressTracker);
|
this.progressTracker = Objects.requireNonNull(progressTracker);
|
||||||
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
||||||
this.auditor = Objects.requireNonNull(auditor);
|
this.auditor = Objects.requireNonNull(auditor);
|
||||||
|
this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
|
@ -111,13 +116,13 @@ public class AnalyticsResultProcessor {
|
||||||
if (progressPercent != null) {
|
if (progressPercent != null) {
|
||||||
progressTracker.analyzingPercent.set(progressPercent);
|
progressTracker.analyzingPercent.set(progressPercent);
|
||||||
}
|
}
|
||||||
TrainedModelDefinition inferenceModel = result.getInferenceModel();
|
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
|
||||||
if (inferenceModel != null) {
|
if (inferenceModelBuilder != null) {
|
||||||
createAndIndexInferenceModel(inferenceModel);
|
createAndIndexInferenceModel(inferenceModelBuilder);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void createAndIndexInferenceModel(TrainedModelDefinition inferenceModel) {
|
private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
|
||||||
TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
|
TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
|
||||||
CountDownLatch latch = storeTrainedModel(trainedModelConfig);
|
CountDownLatch latch = storeTrainedModel(trainedModelConfig);
|
||||||
|
|
||||||
|
@ -131,10 +136,12 @@ public class AnalyticsResultProcessor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition inferenceModel) {
|
private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) {
|
||||||
Instant createTime = Instant.now();
|
Instant createTime = Instant.now();
|
||||||
|
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
|
||||||
|
TrainedModelDefinition definition = inferenceModel.setModelId(modelId).build();
|
||||||
return TrainedModelConfig.builder()
|
return TrainedModelConfig.builder()
|
||||||
.setModelId(analytics.getId() + "-" + createTime.toEpochMilli())
|
.setModelId(modelId)
|
||||||
.setCreatedBy("data-frame-analytics")
|
.setCreatedBy("data-frame-analytics")
|
||||||
.setVersion(Version.CURRENT)
|
.setVersion(Version.CURRENT)
|
||||||
.setCreateTime(createTime)
|
.setCreateTime(createTime)
|
||||||
|
@ -142,7 +149,8 @@ public class AnalyticsResultProcessor {
|
||||||
.setDescription(analytics.getDescription())
|
.setDescription(analytics.getDescription())
|
||||||
.setMetadata(Collections.singletonMap("analytics_config",
|
.setMetadata(Collections.singletonMap("analytics_config",
|
||||||
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
|
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
|
||||||
.setDefinition(inferenceModel)
|
.setDefinition(definition)
|
||||||
|
.setInput(new TrainedModelInput(fieldNames))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,23 +24,25 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
public static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
|
public static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
|
||||||
|
|
||||||
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
|
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
|
||||||
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition) a[2]));
|
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition.Builder) a[2]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
|
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
|
||||||
PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
|
PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
|
||||||
PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.STRICT_PARSER.apply(p, null).build(),
|
// TODO change back to STRICT_PARSER once native side is aligned
|
||||||
INFERENCE_MODEL);
|
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final RowResults rowResults;
|
private final RowResults rowResults;
|
||||||
private final Integer progressPercent;
|
private final Integer progressPercent;
|
||||||
|
private final TrainedModelDefinition.Builder inferenceModelBuilder;
|
||||||
private final TrainedModelDefinition inferenceModel;
|
private final TrainedModelDefinition inferenceModel;
|
||||||
|
|
||||||
public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition inferenceModel) {
|
public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition.Builder inferenceModelBuilder) {
|
||||||
this.rowResults = rowResults;
|
this.rowResults = rowResults;
|
||||||
this.progressPercent = progressPercent;
|
this.progressPercent = progressPercent;
|
||||||
this.inferenceModel = inferenceModel;
|
this.inferenceModelBuilder = inferenceModelBuilder;
|
||||||
|
this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
public RowResults getRowResults() {
|
public RowResults getRowResults() {
|
||||||
|
@ -51,8 +53,8 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
return progressPercent;
|
return progressPercent;
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainedModelDefinition getInferenceModel() {
|
public TrainedModelDefinition.Builder getInferenceModelBuilder() {
|
||||||
return inferenceModel;
|
return inferenceModelBuilder;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -86,6 +86,9 @@ public final class InferenceInternalIndex {
|
||||||
.startObject(TrainedModelConfig.CREATED_BY.getPreferredName())
|
.startObject(TrainedModelConfig.CREATED_BY.getPreferredName())
|
||||||
.field(TYPE, KEYWORD)
|
.field(TYPE, KEYWORD)
|
||||||
.endObject()
|
.endObject()
|
||||||
|
.startObject(TrainedModelConfig.INPUT.getPreferredName())
|
||||||
|
.field(ENABLED, false)
|
||||||
|
.endObject()
|
||||||
.startObject(TrainedModelConfig.VERSION.getPreferredName())
|
.startObject(TrainedModelConfig.VERSION.getPreferredName())
|
||||||
.field(TYPE, KEYWORD)
|
.field(TYPE, KEYWORD)
|
||||||
.endObject()
|
.endObject()
|
||||||
|
@ -95,9 +98,6 @@ public final class InferenceInternalIndex {
|
||||||
.startObject(TrainedModelConfig.CREATE_TIME.getPreferredName())
|
.startObject(TrainedModelConfig.CREATE_TIME.getPreferredName())
|
||||||
.field(TYPE, DATE)
|
.field(TYPE, DATE)
|
||||||
.endObject()
|
.endObject()
|
||||||
.startObject(TrainedModelConfig.DEFINITION.getPreferredName())
|
|
||||||
.field(ENABLED, false)
|
|
||||||
.endObject()
|
|
||||||
.startObject(TrainedModelConfig.TAGS.getPreferredName())
|
.startObject(TrainedModelConfig.TAGS.getPreferredName())
|
||||||
.field(TYPE, KEYWORD)
|
.field(TYPE, KEYWORD)
|
||||||
.endObject()
|
.endObject()
|
||||||
|
|
|
@ -8,32 +8,38 @@ package org.elasticsearch.xpack.ml.inference.persistence;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
import org.elasticsearch.ElasticsearchParseException;
|
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.ResourceAlreadyExistsException;
|
import org.elasticsearch.ResourceAlreadyExistsException;
|
||||||
import org.elasticsearch.ResourceNotFoundException;
|
import org.elasticsearch.ResourceNotFoundException;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.action.DocWriteRequest;
|
import org.elasticsearch.action.DocWriteRequest;
|
||||||
import org.elasticsearch.action.index.IndexAction;
|
import org.elasticsearch.action.bulk.BulkAction;
|
||||||
|
import org.elasticsearch.action.bulk.BulkRequest;
|
||||||
|
import org.elasticsearch.action.bulk.BulkResponse;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.search.SearchAction;
|
import org.elasticsearch.action.search.MultiSearchAction;
|
||||||
import org.elasticsearch.action.search.SearchRequest;
|
import org.elasticsearch.action.search.MultiSearchRequestBuilder;
|
||||||
|
import org.elasticsearch.action.search.MultiSearchResponse;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
import org.elasticsearch.client.Client;
|
import org.elasticsearch.client.Client;
|
||||||
|
import org.elasticsearch.common.CheckedBiFunction;
|
||||||
import org.elasticsearch.common.bytes.BytesReference;
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.ToXContent;
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentType;
|
import org.elasticsearch.common.xcontent.XContentType;
|
||||||
import org.elasticsearch.index.engine.VersionConflictEngineException;
|
import org.elasticsearch.index.engine.VersionConflictEngineException;
|
||||||
|
import org.elasticsearch.index.mapper.MapperService;
|
||||||
import org.elasticsearch.index.query.QueryBuilder;
|
import org.elasticsearch.index.query.QueryBuilder;
|
||||||
import org.elasticsearch.index.query.QueryBuilders;
|
import org.elasticsearch.index.query.QueryBuilders;
|
||||||
import org.elasticsearch.rest.RestStatus;
|
import org.elasticsearch.rest.RestStatus;
|
||||||
import org.elasticsearch.search.sort.SortOrder;
|
import org.elasticsearch.search.sort.SortOrder;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
@ -51,6 +57,8 @@ public class TrainedModelProvider {
|
||||||
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
|
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
|
||||||
private final Client client;
|
private final Client client;
|
||||||
private final NamedXContentRegistry xContentRegistry;
|
private final NamedXContentRegistry xContentRegistry;
|
||||||
|
private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
|
||||||
|
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
|
||||||
|
|
||||||
public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) {
|
public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) {
|
||||||
this.client = client;
|
this.client = client;
|
||||||
|
@ -58,76 +66,178 @@ public class TrainedModelProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) {
|
public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) {
|
||||||
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
|
||||||
XContentBuilder source = trainedModelConfig.toXContent(builder,
|
|
||||||
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
|
|
||||||
|
|
||||||
IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME)
|
if (trainedModelConfig.getDefinition() == null) {
|
||||||
.opType(DocWriteRequest.OpType.CREATE)
|
listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] is required",
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
trainedModelConfig.getModelId(),
|
||||||
.id(trainedModelConfig.getModelId())
|
TrainedModelConfig.DEFINITION.getPreferredName()));
|
||||||
.source(source);
|
return;
|
||||||
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest,
|
|
||||||
ActionListener.wrap(
|
|
||||||
r -> listener.onResponse(true),
|
|
||||||
e -> {
|
|
||||||
logger.error(new ParameterizedMessage(
|
|
||||||
"[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e);
|
|
||||||
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
|
|
||||||
listener.onFailure(new ResourceAlreadyExistsException(
|
|
||||||
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
|
|
||||||
} else {
|
|
||||||
listener.onFailure(
|
|
||||||
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
|
|
||||||
RestStatus.INTERNAL_SERVER_ERROR,
|
|
||||||
e,
|
|
||||||
trainedModelConfig.getModelId()));
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
} catch (IOException e) {
|
|
||||||
// not expected to happen but for the sake of completeness
|
|
||||||
listener.onFailure(new ElasticsearchParseException(
|
|
||||||
Messages.getMessage(Messages.INFERENCE_FAILED_TO_SERIALIZE_MODEL, trainedModelConfig.getModelId()),
|
|
||||||
e));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BulkRequest bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||||
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||||
|
.add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig))
|
||||||
|
.add(createRequest(TrainedModelDefinition.docId(trainedModelConfig.getModelId()), trainedModelConfig.getDefinition()))
|
||||||
|
.request();
|
||||||
|
|
||||||
|
ActionListener<Boolean> wrappedListener = ActionListener.wrap(
|
||||||
|
listener::onResponse,
|
||||||
|
e -> {
|
||||||
|
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
|
||||||
|
listener.onFailure(new ResourceAlreadyExistsException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
|
||||||
|
} else {
|
||||||
|
listener.onFailure(
|
||||||
|
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
e,
|
||||||
|
trainedModelConfig.getModelId()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
ActionListener<BulkResponse> bulkResponseActionListener = ActionListener.wrap(
|
||||||
|
r -> {
|
||||||
|
assert r.getItems().length == 2;
|
||||||
|
if (r.getItems()[0].isFailed()) {
|
||||||
|
logger.error(new ParameterizedMessage(
|
||||||
|
"[{}] failed to store trained model config for inference",
|
||||||
|
trainedModelConfig.getModelId()),
|
||||||
|
r.getItems()[0].getFailure().getCause());
|
||||||
|
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (r.getItems()[1].isFailed()) {
|
||||||
|
logger.error(new ParameterizedMessage(
|
||||||
|
"[{}] failed to store trained model definition for inference",
|
||||||
|
trainedModelConfig.getModelId()),
|
||||||
|
r.getItems()[1].getFailure().getCause());
|
||||||
|
wrappedListener.onFailure(r.getItems()[1].getFailure().getCause());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
wrappedListener.onResponse(true);
|
||||||
|
},
|
||||||
|
wrappedListener::onFailure
|
||||||
|
);
|
||||||
|
|
||||||
|
executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest, bulkResponseActionListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void getTrainedModel(String modelId, ActionListener<TrainedModelConfig> listener) {
|
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
|
||||||
|
|
||||||
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
||||||
.idsQuery()
|
.idsQuery()
|
||||||
.addIds(modelId));
|
.addIds(modelId));
|
||||||
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
MultiSearchRequestBuilder multiSearchRequestBuilder = client.prepareMultiSearch()
|
||||||
.setQuery(queryBuilder)
|
.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||||
// use sort to get the last
|
.setQuery(queryBuilder)
|
||||||
.addSort("_index", SortOrder.DESC)
|
// use sort to get the last
|
||||||
.setSize(1)
|
.addSort("_index", SortOrder.DESC)
|
||||||
.request();
|
.setSize(1)
|
||||||
|
.request());
|
||||||
|
|
||||||
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest,
|
if (includeDefinition) {
|
||||||
ActionListener.wrap(
|
multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||||
searchResponse -> {
|
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
|
||||||
if (searchResponse.getHits().getHits().length == 0) {
|
.idsQuery()
|
||||||
|
.addIds(TrainedModelDefinition.docId(modelId))))
|
||||||
|
// use sort to get the last
|
||||||
|
.addSort("_index", SortOrder.DESC)
|
||||||
|
.setSize(1)
|
||||||
|
.request());
|
||||||
|
}
|
||||||
|
|
||||||
|
ActionListener<MultiSearchResponse> multiSearchResponseActionListener = ActionListener.wrap(
|
||||||
|
multiSearchResponse -> {
|
||||||
|
TrainedModelConfig.Builder builder;
|
||||||
|
TrainedModelDefinition definition;
|
||||||
|
try {
|
||||||
|
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
|
||||||
|
} catch (ResourceNotFoundException ex) {
|
||||||
|
listener.onFailure(new ResourceNotFoundException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||||
|
return;
|
||||||
|
} catch (Exception ex) {
|
||||||
|
listener.onFailure(ex);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (includeDefinition) {
|
||||||
|
try {
|
||||||
|
definition = handleSearchItem(multiSearchResponse.getResponses()[1],
|
||||||
|
modelId,
|
||||||
|
this::parseModelDefinitionDocLenientlyFromSource);
|
||||||
|
builder.setDefinition(definition);
|
||||||
|
} catch (ResourceNotFoundException ex) {
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
listener.onFailure(new ResourceNotFoundException(
|
||||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||||
|
return;
|
||||||
|
} catch (Exception ex) {
|
||||||
|
listener.onFailure(ex);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef();
|
}
|
||||||
parseInferenceDocLenientlyFromSource(source, modelId, listener);
|
listener.onResponse(builder.build());
|
||||||
},
|
},
|
||||||
listener::onFailure));
|
listener::onFailure
|
||||||
|
);
|
||||||
|
|
||||||
|
executeAsyncWithOrigin(client,
|
||||||
|
ML_ORIGIN,
|
||||||
|
MultiSearchAction.INSTANCE,
|
||||||
|
multiSearchRequestBuilder.request(),
|
||||||
|
multiSearchResponseActionListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private void parseInferenceDocLenientlyFromSource(BytesReference source,
|
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
|
||||||
String modelId,
|
String resourceId,
|
||||||
ActionListener<TrainedModelConfig> modelListener) {
|
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
|
||||||
|
if (item.isFailure()) {
|
||||||
|
throw item.getFailure();
|
||||||
|
}
|
||||||
|
if (item.getResponse().getHits().getHits().length == 0) {
|
||||||
|
throw new ResourceNotFoundException(resourceId);
|
||||||
|
}
|
||||||
|
return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception {
|
||||||
try (InputStream stream = source.streamInput();
|
try (InputStream stream = source.streamInput();
|
||||||
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
|
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
|
||||||
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
|
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
|
||||||
modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build());
|
return TrainedModelConfig.fromXContent(parser, true);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e);
|
logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e);
|
||||||
modelListener.onFailure(e);
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws Exception {
|
||||||
|
try (InputStream stream = source.streamInput();
|
||||||
|
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
|
||||||
|
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
|
||||||
|
return TrainedModelDefinition.fromXContent(parser, true).build();
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private IndexRequest createRequest(String docId, ToXContentObject body) {
|
||||||
|
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||||
|
XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
|
||||||
|
|
||||||
|
return new IndexRequest()
|
||||||
|
.opType(DocWriteRequest.OpType.CREATE)
|
||||||
|
.id(docId)
|
||||||
|
.source(source);
|
||||||
|
} catch (IOException ex) {
|
||||||
|
// This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again
|
||||||
|
// that is not the users fault. We did something wrong and should throw.
|
||||||
|
throw ExceptionsHelper.serverError(
|
||||||
|
new ParameterizedMessage("Unexpected serialization exception for [{}]", docId).getFormattedMessage(),
|
||||||
|
ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,9 +126,10 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||||
return null;
|
return null;
|
||||||
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
|
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
|
||||||
|
|
||||||
TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
|
List<String> expectedFieldNames = Arrays.asList("foo", "bar", "baz");
|
||||||
|
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(JOB_ID);
|
||||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
|
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
|
||||||
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
AnalyticsResultProcessor resultProcessor = createResultProcessor(expectedFieldNames);
|
||||||
|
|
||||||
resultProcessor.process(process);
|
resultProcessor.process(process);
|
||||||
resultProcessor.awaitForCompletion();
|
resultProcessor.awaitForCompletion();
|
||||||
|
@ -142,7 +143,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||||
assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));
|
assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));
|
||||||
assertThat(storedModel.getTags(), contains(JOB_ID));
|
assertThat(storedModel.getTags(), contains(JOB_ID));
|
||||||
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
|
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
|
||||||
assertThat(storedModel.getDefinition(), equalTo(inferenceModel));
|
assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build()));
|
||||||
|
assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
|
||||||
Map<String, Object> metadata = storedModel.getMetadata();
|
Map<String, Object> metadata = storedModel.getMetadata();
|
||||||
assertThat(metadata.size(), equalTo(1));
|
assertThat(metadata.size(), equalTo(1));
|
||||||
assertThat(metadata, hasKey("analytics_config"));
|
assertThat(metadata, hasKey("analytics_config"));
|
||||||
|
@ -166,7 +168,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||||
return null;
|
return null;
|
||||||
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
|
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
|
||||||
|
|
||||||
TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
|
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder("failed_model");
|
||||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
|
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
|
||||||
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
||||||
|
|
||||||
|
@ -192,7 +194,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private AnalyticsResultProcessor createResultProcessor() {
|
private AnalyticsResultProcessor createResultProcessor() {
|
||||||
|
return createResultProcessor(Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
|
||||||
return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider,
|
return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider,
|
||||||
auditor);
|
auditor, fieldNames);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,6 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvide
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -33,7 +32,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
|
||||||
protected AnalyticsResult createTestInstance() {
|
protected AnalyticsResult createTestInstance() {
|
||||||
RowResults rowResults = null;
|
RowResults rowResults = null;
|
||||||
Integer progressPercent = null;
|
Integer progressPercent = null;
|
||||||
TrainedModelDefinition inferenceModel = null;
|
TrainedModelDefinition.Builder inferenceModel = null;
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
rowResults = RowResultsTests.createRandom();
|
rowResults = RowResultsTests.createRandom();
|
||||||
}
|
}
|
||||||
|
@ -41,13 +40,13 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
|
||||||
progressPercent = randomIntBetween(0, 100);
|
progressPercent = randomIntBetween(0, 100);
|
||||||
}
|
}
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build();
|
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(null);
|
||||||
}
|
}
|
||||||
return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
|
return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected AnalyticsResult doParseInstance(XContentParser parser) throws IOException {
|
protected AnalyticsResult doParseInstance(XContentParser parser) {
|
||||||
return AnalyticsResult.PARSER.apply(parser, null);
|
return AnalyticsResult.PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,17 @@
|
||||||
package org.elasticsearch.xpack.ml.integration;
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
import org.elasticsearch.Version;
|
import org.elasticsearch.Version;
|
||||||
|
import org.elasticsearch.action.delete.DeleteRequest;
|
||||||
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.search.SearchModule;
|
import org.elasticsearch.search.SearchModule;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
||||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
|
@ -75,29 +80,75 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||||
|
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
|
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
||||||
assertThat(getConfigHolder.get(), is(not(nullValue())));
|
assertThat(getConfigHolder.get(), is(not(nullValue())));
|
||||||
assertThat(getConfigHolder.get(), equalTo(config));
|
assertThat(getConfigHolder.get(), equalTo(config));
|
||||||
|
assertThat(getConfigHolder.get().getDefinition(), is(not(nullValue())));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testGetTrainedModelConfigWithoutDefinition() throws Exception {
|
||||||
|
String modelId = "test-get-trained-model-config-no-definition";
|
||||||
|
TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
|
||||||
|
TrainedModelConfig config = configBuilder.build();
|
||||||
|
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
|
||||||
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||||
|
|
||||||
|
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
|
||||||
|
assertThat(putConfigHolder.get(), is(true));
|
||||||
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||||
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
|
|
||||||
|
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder);
|
||||||
|
assertThat(getConfigHolder.get(), is(not(nullValue())));
|
||||||
|
assertThat(getConfigHolder.get(),
|
||||||
|
equalTo(configBuilder.setCreateTime(config.getCreateTime()).setDefinition((TrainedModelDefinition) null).build()));
|
||||||
|
assertThat(getConfigHolder.get().getDefinition(), is(nullValue()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetMissingTrainingModelConfig() throws Exception {
|
public void testGetMissingTrainingModelConfig() throws Exception {
|
||||||
String modelId = "test-get-missing-trained-model-config";
|
String modelId = "test-get-missing-trained-model-config";
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder);
|
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
||||||
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||||
assertThat(exceptionHolder.get().getMessage(),
|
assertThat(exceptionHolder.get().getMessage(),
|
||||||
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static TrainedModelConfig buildTrainedModelConfig(String modelId) {
|
public void testGetMissingTrainingModelConfigDefinition() throws Exception {
|
||||||
|
String modelId = "test-get-missing-trained-model-config-definition";
|
||||||
|
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).build();
|
||||||
|
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
|
||||||
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||||
|
|
||||||
|
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
|
||||||
|
assertThat(putConfigHolder.get(), is(true));
|
||||||
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||||
|
|
||||||
|
client().delete(new DeleteRequest(InferenceIndexConstants.LATEST_INDEX_NAME)
|
||||||
|
.id(TrainedModelDefinition.docId(config.getModelId()))
|
||||||
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE))
|
||||||
|
.actionGet();
|
||||||
|
|
||||||
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
|
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
||||||
|
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||||
|
assertThat(exceptionHolder.get().getMessage(),
|
||||||
|
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
|
||||||
return TrainedModelConfig.builder()
|
return TrainedModelConfig.builder()
|
||||||
.setCreatedBy("ml_test")
|
.setCreatedBy("ml_test")
|
||||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
|
||||||
.setDescription("trained model config for test")
|
.setDescription("trained model config for test")
|
||||||
.setModelId(modelId)
|
.setModelId(modelId)
|
||||||
.setVersion(Version.CURRENT)
|
.setVersion(Version.CURRENT)
|
||||||
.build();
|
.setInput(TrainedModelInputTests.createRandomInput());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TrainedModelConfig buildTrainedModelConfig(String modelId) {
|
||||||
|
return buildTrainedModelConfigBuilder(modelId).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue