[7.x] [ML][Inference] separating definition and config object storage (#48651) (#48695)

* [ML][Inference] separating definition and config object storage (#48651)

This separates out the `definition` object from being stored within the configuration object in the index. 

This allows us to gather the config object without decompressing a potentially large definition.

Additionally, `input` is moved to the TrainedModelConfig object and out of the definition. This is so the trained input fields are accessible outside the potentially large model definition.
This commit is contained in:
Benjamin Trent 2019-10-30 13:27:29 -04:00 committed by GitHub
parent 0476f014bc
commit c9ead80c31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 678 additions and 297 deletions

View File

@ -46,6 +46,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField DEFINITION = new ParseField("definition"); public static final ParseField DEFINITION = new ParseField("definition");
public static final ParseField TAGS = new ParseField("tags"); public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true, true,
@ -64,6 +65,7 @@ public class TrainedModelConfig implements ToXContentObject {
DEFINITION); DEFINITION);
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
} }
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@ -78,6 +80,7 @@ public class TrainedModelConfig implements ToXContentObject {
private final TrainedModelDefinition definition; private final TrainedModelDefinition definition;
private final List<String> tags; private final List<String> tags;
private final Map<String, Object> metadata; private final Map<String, Object> metadata;
private final TrainedModelInput input;
TrainedModelConfig(String modelId, TrainedModelConfig(String modelId,
String createdBy, String createdBy,
@ -86,7 +89,8 @@ public class TrainedModelConfig implements ToXContentObject {
Instant createTime, Instant createTime,
TrainedModelDefinition definition, TrainedModelDefinition definition,
List<String> tags, List<String> tags,
Map<String, Object> metadata) { Map<String, Object> metadata,
TrainedModelInput input) {
this.modelId = modelId; this.modelId = modelId;
this.createdBy = createdBy; this.createdBy = createdBy;
this.version = version; this.version = version;
@ -95,6 +99,7 @@ public class TrainedModelConfig implements ToXContentObject {
this.description = description; this.description = description;
this.tags = tags == null ? null : Collections.unmodifiableList(tags); this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = input;
} }
public String getModelId() { public String getModelId() {
@ -129,6 +134,10 @@ public class TrainedModelConfig implements ToXContentObject {
return definition; return definition;
} }
public TrainedModelInput getInput() {
return input;
}
public static Builder builder() { public static Builder builder() {
return new Builder(); return new Builder();
} }
@ -160,6 +169,9 @@ public class TrainedModelConfig implements ToXContentObject {
if (metadata != null) { if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata); builder.field(METADATA.getPreferredName(), metadata);
} }
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -181,6 +193,7 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(createTime, that.createTime) && Objects.equals(createTime, that.createTime) &&
Objects.equals(definition, that.definition) && Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) && Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(metadata, that.metadata); Objects.equals(metadata, that.metadata);
} }
@ -193,7 +206,8 @@ public class TrainedModelConfig implements ToXContentObject {
definition, definition,
description, description,
tags, tags,
metadata); metadata,
input);
} }
@ -207,6 +221,7 @@ public class TrainedModelConfig implements ToXContentObject {
private Map<String, Object> metadata; private Map<String, Object> metadata;
private List<String> tags; private List<String> tags;
private TrainedModelDefinition definition; private TrainedModelDefinition definition;
private TrainedModelInput input;
public Builder setModelId(String modelId) { public Builder setModelId(String modelId) {
this.modelId = modelId; this.modelId = modelId;
@ -257,6 +272,11 @@ public class TrainedModelConfig implements ToXContentObject {
return this; return this;
} }
public Builder setInput(TrainedModelInput input) {
this.input = input;
return this;
}
public TrainedModelConfig build() { public TrainedModelConfig build() {
return new TrainedModelConfig( return new TrainedModelConfig(
modelId, modelId,
@ -266,7 +286,9 @@ public class TrainedModelConfig implements ToXContentObject {
createTime, createTime,
definition, definition,
tags, tags,
metadata); metadata,
input);
} }
} }
} }

View File

@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
@ -39,7 +38,6 @@ public class TrainedModelDefinition implements ToXContentObject {
public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
public static final ParseField INPUT = new ParseField("input");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true, true,
@ -53,7 +51,6 @@ public class TrainedModelDefinition implements ToXContentObject {
(p, c, n) -> p.namedObject(PreProcessor.class, n, null), (p, c, n) -> p.namedObject(PreProcessor.class, n, null),
(trainedModelDefBuilder) -> {/* Does not matter client side*/ }, (trainedModelDefBuilder) -> {/* Does not matter client side*/ },
PREPROCESSORS); PREPROCESSORS);
PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
} }
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException { public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
@ -62,12 +59,10 @@ public class TrainedModelDefinition implements ToXContentObject {
private final TrainedModel trainedModel; private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors; private final List<PreProcessor> preProcessors;
private final Input input;
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) { TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = trainedModel; this.trainedModel = trainedModel;
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
this.input = input;
} }
@Override @Override
@ -83,9 +78,6 @@ public class TrainedModelDefinition implements ToXContentObject {
true, true,
PREPROCESSORS.getPreferredName(), PREPROCESSORS.getPreferredName(),
preProcessors); preProcessors);
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -98,10 +90,6 @@ public class TrainedModelDefinition implements ToXContentObject {
return preProcessors; return preProcessors;
} }
public Input getInput() {
return input;
}
@Override @Override
public String toString() { public String toString() {
return Strings.toString(this); return Strings.toString(this);
@ -113,20 +101,18 @@ public class TrainedModelDefinition implements ToXContentObject {
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o; TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) && return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(preProcessors, that.preProcessors) && Objects.equals(preProcessors, that.preProcessors);
Objects.equals(input, that.input);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(trainedModel, preProcessors, input); return Objects.hash(trainedModel, preProcessors);
} }
public static class Builder { public static class Builder {
private List<PreProcessor> preProcessors; private List<PreProcessor> preProcessors;
private TrainedModel trainedModel; private TrainedModel trainedModel;
private Input input;
public Builder setPreProcessors(List<PreProcessor> preProcessors) { public Builder setPreProcessors(List<PreProcessor> preProcessors) {
this.preProcessors = preProcessors; this.preProcessors = preProcessors;
@ -138,71 +124,14 @@ public class TrainedModelDefinition implements ToXContentObject {
return this; return this;
} }
public Builder setInput(Input input) {
this.input = input;
return this;
}
private Builder setTrainedModel(List<TrainedModel> trainedModel) { private Builder setTrainedModel(List<TrainedModel> trainedModel) {
assert trainedModel.size() == 1; assert trainedModel.size() == 1;
return setTrainedModel(trainedModel.get(0)); return setTrainedModel(trainedModel.get(0));
} }
public TrainedModelDefinition build() { public TrainedModelDefinition build() {
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input); return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
} }
} }
public static class Input implements ToXContentObject {
public static final String NAME = "trained_mode_definition_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Input((List<String>)a[0]));
static {
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
}
public static Input fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
private final List<String> fieldNames;
public Input(List<String> fieldNames) {
this.fieldNames = fieldNames;
}
public List<String> getFieldNames() {
return fieldNames;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fieldNames != null) {
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
return Objects.equals(fieldNames, that.fieldNames);
}
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}
}
} }

View File

@ -0,0 +1,82 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public class TrainedModelInput implements ToXContentObject {
public static final String NAME = "trained_model_config_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TrainedModelInput, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new TrainedModelInput((List<String>) a[0]));
static {
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
}
private final List<String> fieldNames;
public TrainedModelInput(List<String> fieldNames) {
this.fieldNames = fieldNames;
}
public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
public List<String> getFieldNames() {
return fieldNames;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fieldNames != null) {
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelInput that = (TrainedModelInput) o;
return Objects.equals(fieldNames, that.fieldNames);
}
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}
}

View File

@ -63,7 +63,8 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null : randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
} }
@Override @Override

View File

@ -64,10 +64,7 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
TargetMeanEncodingTests.createRandom())) TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors) .limit(numberOfProcessors)
.collect(Collectors.toList())) .collect(Collectors.toList()))
.setTrainedModel(randomFrom(TreeTests.createRandom())) .setTrainedModel(randomFrom(TreeTests.createRandom()));
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(1, 10))
.collect(Collectors.toList())));
} }
@Override @Override

View File

@ -0,0 +1,58 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TrainedModelInputTests extends AbstractXContentTestCase<TrainedModelInput> {
@Override
protected TrainedModelInput doParseInstance(XContentParser parser) throws IOException {
return TrainedModelInput.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
public static TrainedModelInput createRandomInput() {
return new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(1, 10))
.collect(Collectors.toList()));
}
@Override
protected TrainedModelInput createTestInstance() {
return createRandomInput();
}
}

View File

@ -42,6 +42,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField DEFINITION = new ParseField("definition"); public static final ParseField DEFINITION = new ParseField("definition");
public static final ParseField TAGS = new ParseField("tags"); public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true); public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
@ -61,10 +62,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
ObjectParser.ValueType.VALUE); ObjectParser.ValueType.VALUE);
parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
parser.declareObject(TrainedModelConfig.Builder::setDefinition,
(p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
DEFINITION);
parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE); parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE);
parser.declareObject(TrainedModelConfig.Builder::setInput,
(p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields),
INPUT);
return parser; return parser;
} }
@ -79,10 +80,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private final Instant createTime; private final Instant createTime;
private final List<String> tags; private final List<String> tags;
private final Map<String, Object> metadata; private final Map<String, Object> metadata;
private final TrainedModelInput input;
// TODO how to reference and store large models that will not be executed in Java???
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
// TODO Should this be lazily parsed when loading via the index???
private final TrainedModelDefinition definition; private final TrainedModelDefinition definition;
TrainedModelConfig(String modelId, TrainedModelConfig(String modelId,
@ -92,7 +91,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Instant createTime, Instant createTime,
TrainedModelDefinition definition, TrainedModelDefinition definition,
List<String> tags, List<String> tags,
Map<String, Object> metadata) { Map<String, Object> metadata,
TrainedModelInput input) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION); this.version = ExceptionsHelper.requireNonNull(version, VERSION);
@ -101,6 +101,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.description = description; this.description = description;
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS)); this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
} }
public TrainedModelConfig(StreamInput in) throws IOException { public TrainedModelConfig(StreamInput in) throws IOException {
@ -112,6 +113,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
definition = in.readOptionalWriteable(TrainedModelDefinition::new); definition = in.readOptionalWriteable(TrainedModelDefinition::new);
tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
metadata = in.readMap(); metadata = in.readMap();
input = new TrainedModelInput(in);
} }
public String getModelId() { public String getModelId() {
@ -147,6 +149,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return definition; return definition;
} }
public TrainedModelInput getInput() {
return input;
}
public static Builder builder() { public static Builder builder() {
return new Builder(); return new Builder();
} }
@ -161,6 +167,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
out.writeOptionalWriteable(definition); out.writeOptionalWriteable(definition);
out.writeCollection(tags, StreamOutput::writeString); out.writeCollection(tags, StreamOutput::writeString);
out.writeMap(metadata); out.writeMap(metadata);
input.writeTo(out);
} }
@Override @Override
@ -173,7 +180,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
builder.field(DESCRIPTION.getPreferredName(), description); builder.field(DESCRIPTION.getPreferredName(), description);
} }
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
if (definition != null) {
// We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
builder.field(DEFINITION.getPreferredName(), definition); builder.field(DEFINITION.getPreferredName(), definition);
} }
builder.field(TAGS.getPreferredName(), tags); builder.field(TAGS.getPreferredName(), tags);
@ -183,6 +192,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
} }
builder.field(INPUT.getPreferredName(), input);
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -204,6 +214,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Objects.equals(createTime, that.createTime) && Objects.equals(createTime, that.createTime) &&
Objects.equals(definition, that.definition) && Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) && Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(metadata, that.metadata); Objects.equals(metadata, that.metadata);
} }
@ -216,7 +227,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
definition, definition,
description, description,
tags, tags,
metadata); metadata,
input);
} }
public static class Builder { public static class Builder {
@ -228,6 +240,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private Instant createTime; private Instant createTime;
private List<String> tags = Collections.emptyList(); private List<String> tags = Collections.emptyList();
private Map<String, Object> metadata; private Map<String, Object> metadata;
private TrainedModelInput input;
private TrainedModelDefinition definition; private TrainedModelDefinition definition;
public Builder setModelId(String modelId) { public Builder setModelId(String modelId) {
@ -279,9 +292,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this; return this;
} }
public Builder setInput(TrainedModelInput input) {
this.input = input;
return this;
}
// TODO move to REST level instead of here in the builder // TODO move to REST level instead of here in the builder
public void validate() { public void validate() {
// We require a definition to be available until we support other means of supplying the definition // We require a definition to be available here even though it will be stored in a different doc
ExceptionsHelper.requireNonNull(definition, DEFINITION); ExceptionsHelper.requireNonNull(definition, DEFINITION);
ExceptionsHelper.requireNonNull(modelId, MODEL_ID); ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
@ -320,7 +338,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
createTime == null ? Instant.now() : createTime, createTime == null ? Instant.now() : createTime,
definition, definition,
tags, tags,
metadata); metadata,
input);
} }
} }
} }

View File

@ -5,16 +5,17 @@
*/ */
package org.elasticsearch.xpack.core.ml.inference; package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
@ -23,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrai
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
@ -31,11 +33,10 @@ import java.util.Objects;
public class TrainedModelDefinition implements ToXContentObject, Writeable { public class TrainedModelDefinition implements ToXContentObject, Writeable {
public static final String NAME = "trained_mode_definition"; public static final String NAME = "trained_model_definition";
public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
public static final ParseField INPUT = new ParseField("input");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true); public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
@ -44,7 +45,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) { private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) {
ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME, ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
ignoreUnknownFields, ignoreUnknownFields,
TrainedModelDefinition.Builder::new); TrainedModelDefinition.Builder::builderForParser);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel, parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
(p, c, n) -> ignoreUnknownFields ? (p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedTrainedModel.class, n, null) : p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
@ -57,7 +58,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
p.namedObject(StrictlyParsedPreProcessor.class, n, null), p.namedObject(StrictlyParsedPreProcessor.class, n, null),
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true), (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
PREPROCESSORS); PREPROCESSORS);
parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT); parser.declareString(TrainedModelDefinition.Builder::setModelId, TrainedModelConfig.MODEL_ID);
return parser; return parser;
} }
@ -65,27 +66,31 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
} }
public static String docId(String modelId) {
return NAME + "-" + modelId;
}
private final TrainedModel trainedModel; private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors; private final List<PreProcessor> preProcessors;
private final Input input; private final String modelId;
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) { private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, @Nullable String modelId) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL); this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
this.input = ExceptionsHelper.requireNonNull(input, INPUT); this.modelId = modelId;
} }
public TrainedModelDefinition(StreamInput in) throws IOException { public TrainedModelDefinition(StreamInput in) throws IOException {
this.trainedModel = in.readNamedWriteable(TrainedModel.class); this.trainedModel = in.readNamedWriteable(TrainedModel.class);
this.preProcessors = in.readNamedWriteableList(PreProcessor.class); this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
this.input = new Input(in); this.modelId = in.readOptionalString();
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(trainedModel); out.writeNamedWriteable(trainedModel);
out.writeNamedWriteableList(preProcessors); out.writeNamedWriteableList(preProcessors);
input.writeTo(out); out.writeOptionalString(modelId);
} }
@Override @Override
@ -101,7 +106,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
true, true,
PREPROCESSORS.getPreferredName(), PREPROCESSORS.getPreferredName(),
preProcessors); preProcessors);
builder.field(INPUT.getPreferredName(), input); if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
assert modelId != null;
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -114,10 +123,6 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
return preProcessors; return preProcessors;
} }
public Input getInput() {
return input;
}
@Override @Override
public String toString() { public String toString() {
return Strings.toString(this); return Strings.toString(this);
@ -129,21 +134,21 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o; TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) && return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(input, that.input) && Objects.equals(preProcessors, that.preProcessors) &&
Objects.equals(preProcessors, that.preProcessors); Objects.equals(modelId, that.modelId);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(trainedModel, input, preProcessors); return Objects.hash(trainedModel, preProcessors, modelId);
} }
public static class Builder { public static class Builder {
private List<PreProcessor> preProcessors; private List<PreProcessor> preProcessors;
private TrainedModel trainedModel; private TrainedModel trainedModel;
private String modelId;
private boolean processorsInOrder; private boolean processorsInOrder;
private Input input;
private static Builder builderForParser() { private static Builder builderForParser() {
return new Builder(false); return new Builder(false);
@ -167,8 +172,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
return this; return this;
} }
public Builder setInput(Input input) { public Builder setModelId(String modelId) {
this.input = input; this.modelId = modelId;
return this; return this;
} }
@ -188,71 +193,8 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) { if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects"); throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
} }
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input); return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.modelId);
} }
} }
public static class Input implements ToXContentObject, Writeable {
public static final String NAME = "trained_mode_definition_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");
public static final ConstructingObjectParser<Input, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<Input, Void> STRICT_PARSER = createParser(false);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Input, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<Input, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new Input((List<String>)a[0]));
parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
return parser;
}
public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
private final List<String> fieldNames;
public Input(List<String> fieldNames) {
this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
}
public Input(StreamInput in) throws IOException {
this.fieldNames = Collections.unmodifiableList(in.readStringList());
}
public List<String> getFieldNames() {
return fieldNames;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(fieldNames);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
return Objects.equals(fieldNames, that.fieldNames);
}
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}
}
} }

View File

@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class TrainedModelInput implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_config_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");
public static final ConstructingObjectParser<TrainedModelInput, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<TrainedModelInput, Void> STRICT_PARSER = createParser(false);
private final List<String> fieldNames;
public TrainedModelInput(List<String> fieldNames) {
this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
}
public TrainedModelInput(StreamInput in) throws IOException {
this.fieldNames = Collections.unmodifiableList(in.readStringList());
}
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<TrainedModelInput, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<TrainedModelInput, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new TrainedModelInput((List<String>) a[0]));
parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
return parser;
}
public static TrainedModelInput fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
public List<String> getFieldNames() {
return fieldNames;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(fieldNames);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelInput that = (TrainedModelInput) o;
return Objects.equals(fieldNames, that.fieldNames);
}
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}
}

View File

@ -82,9 +82,8 @@ public final class Messages {
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL =
"Failed to serialize the trained model [{0}] with version [{1}] for storage";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created"; public static final String JOB_AUDIT_CREATED = "Job created";

View File

@ -7,15 +7,20 @@ package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.MlStrings; import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
@ -28,7 +33,9 @@ import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> { public class TrainedModelConfigTests extends AbstractSerializingTestCase<TrainedModelConfig> {
@ -63,9 +70,10 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
Version.CURRENT, Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100), randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()), Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), null, // is not parsed so should not be provided
tags, tags,
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput());
} }
@Override @Override
@ -88,6 +96,28 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
return new NamedWriteableRegistry(entries); return new NamedWriteableRegistry(entries);
} }
public void testToXContentWithParams() throws IOException {
TrainedModelConfig config = new TrainedModelConfig(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
Version.CURRENT,
randomBoolean() ? null : randomAlphaOfLength(100),
Instant.ofEpochMilli(randomNonNegativeLong()),
TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(),
Collections.emptyList(),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput());
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("definition"));
reference = XContentHelper.toXContent(config,
XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
false);
assertThat(reference.utf8ToString(), not(containsString("definition")));
}
public void testValidateWithNullDefinition() { public void testValidateWithNullDefinition() {
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate()); IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
assertThat(ex.getMessage(), equalTo("[definition] must not be null.")); assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
@ -97,7 +127,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
String modelId = "InvalidID-"; String modelId = "InvalidID-";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId))); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
} }
@ -106,7 +136,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining()); String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
ElasticsearchException ex = expectThrows(ElasticsearchException.class, ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT))); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
} }
@ -115,21 +145,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
String modelId = "simplemodel"; String modelId = "simplemodel";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
.setCreateTime(Instant.now()) .setCreateTime(Instant.now())
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation")); assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
ex = expectThrows(ElasticsearchException.class, ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation")); assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
ex = expectThrows(ElasticsearchException.class, ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
.setCreatedBy("ml_user") .setCreatedBy("ml_user")
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation")); assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));

View File

@ -58,9 +58,10 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
return field -> !field.isEmpty(); return field -> !field.isEmpty();
} }
public static TrainedModelDefinition.Builder createRandomBuilder() { public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) {
int numberOfProcessors = randomIntBetween(1, 10); int numberOfProcessors = randomIntBetween(1, 10);
return new TrainedModelDefinition.Builder() return new TrainedModelDefinition.Builder()
.setModelId(modelId)
.setPreProcessors( .setPreProcessors(
randomBoolean() ? null : randomBoolean() ? null :
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(), Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
@ -68,22 +69,11 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
TargetMeanEncodingTests.createRandom())) TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors) .limit(numberOfProcessors)
.collect(Collectors.toList())) .collect(Collectors.toList()))
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(1, 10))
.collect(Collectors.toList())))
.setTrainedModel(randomFrom(TreeTests.createRandom())); .setTrainedModel(randomFrom(TreeTests.createRandom()));
} }
private static final String ENSEMBLE_MODEL = "" + private static final String ENSEMBLE_MODEL = "" +
"{\n" + "{\n" +
" \"input\": {\n" +
" \"field_names\": [\n" +
" \"col1\",\n" +
" \"col2\",\n" +
" \"col3\",\n" +
" \"col4\"\n" +
" ]\n" +
" },\n" +
" \"preprocessors\": [\n" + " \"preprocessors\": [\n" +
" {\n" + " {\n" +
" \"one_hot_encoding\": {\n" + " \"one_hot_encoding\": {\n" +
@ -203,14 +193,6 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
"}"; "}";
private static final String TREE_MODEL = "" + private static final String TREE_MODEL = "" +
"{\n" + "{\n" +
" \"input\": {\n" +
" \"field_names\": [\n" +
" \"col1\",\n" +
" \"col2\",\n" +
" \"col3\",\n" +
" \"col4\"\n" +
" ]\n" +
" },\n" +
" \"preprocessors\": [\n" + " \"preprocessors\": [\n" +
" {\n" + " {\n" +
" \"one_hot_encoding\": {\n" + " \"one_hot_encoding\": {\n" +
@ -293,7 +275,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
@Override @Override
protected TrainedModelDefinition createTestInstance() { protected TrainedModelDefinition createTestInstance() {
return createRandomBuilder().build(); return createRandomBuilder(null).build();
} }
@Override @Override

View File

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.junit.Before;
import java.io.IOException;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TrainedModelInputTests extends AbstractSerializingTestCase<TrainedModelInput> {
private boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected TrainedModelInput doParseInstance(XContentParser parser) throws IOException {
return TrainedModelInput.fromXContent(parser, lenient);
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
return field -> !field.isEmpty();
}
public static TrainedModelInput createRandomInput() {
return new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomInt(10))
.collect(Collectors.toList()));
}
@Override
protected TrainedModelInput createTestInstance() {
return createRandomInput();
}
@Override
protected Writeable.Reader<TrainedModelInput> instanceReader() {
return TrainedModelInput::new;
}
}

View File

@ -366,7 +366,7 @@ public class AnalyticsProcessManager {
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
dataExtractorFactory.newExtractor(true)); dataExtractorFactory.newExtractor(true));
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(), resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
trainedModelProvider, auditor); trainedModelProvider, auditor, dataExtractor.getFieldNames());
return true; return true;
} }

View File

@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@ -26,6 +27,7 @@ import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import java.time.Instant; import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -41,18 +43,21 @@ public class AnalyticsResultProcessor {
private final ProgressTracker progressTracker; private final ProgressTracker progressTracker;
private final TrainedModelProvider trainedModelProvider; private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor; private final DataFrameAnalyticsAuditor auditor;
private final List<String> fieldNames;
private final CountDownLatch completionLatch = new CountDownLatch(1); private final CountDownLatch completionLatch = new CountDownLatch(1);
private volatile String failure; private volatile String failure;
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker, Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor) { TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor,
List<String> fieldNames) {
this.analytics = Objects.requireNonNull(analytics); this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.isProcessKilled = Objects.requireNonNull(isProcessKilled); this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
this.progressTracker = Objects.requireNonNull(progressTracker); this.progressTracker = Objects.requireNonNull(progressTracker);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor); this.auditor = Objects.requireNonNull(auditor);
this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames));
} }
@Nullable @Nullable
@ -111,13 +116,13 @@ public class AnalyticsResultProcessor {
if (progressPercent != null) { if (progressPercent != null) {
progressTracker.analyzingPercent.set(progressPercent); progressTracker.analyzingPercent.set(progressPercent);
} }
TrainedModelDefinition inferenceModel = result.getInferenceModel(); TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
if (inferenceModel != null) { if (inferenceModelBuilder != null) {
createAndIndexInferenceModel(inferenceModel); createAndIndexInferenceModel(inferenceModelBuilder);
} }
} }
private void createAndIndexInferenceModel(TrainedModelDefinition inferenceModel) { private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel); TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
CountDownLatch latch = storeTrainedModel(trainedModelConfig); CountDownLatch latch = storeTrainedModel(trainedModelConfig);
@ -131,10 +136,12 @@ public class AnalyticsResultProcessor {
} }
} }
private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition inferenceModel) { private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) {
Instant createTime = Instant.now(); Instant createTime = Instant.now();
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
TrainedModelDefinition definition = inferenceModel.setModelId(modelId).build();
return TrainedModelConfig.builder() return TrainedModelConfig.builder()
.setModelId(analytics.getId() + "-" + createTime.toEpochMilli()) .setModelId(modelId)
.setCreatedBy("data-frame-analytics") .setCreatedBy("data-frame-analytics")
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.setCreateTime(createTime) .setCreateTime(createTime)
@ -142,7 +149,8 @@ public class AnalyticsResultProcessor {
.setDescription(analytics.getDescription()) .setDescription(analytics.getDescription())
.setMetadata(Collections.singletonMap("analytics_config", .setMetadata(Collections.singletonMap("analytics_config",
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
.setDefinition(inferenceModel) .setDefinition(definition)
.setInput(new TrainedModelInput(fieldNames))
.build(); .build();
} }

View File

@ -24,23 +24,25 @@ public class AnalyticsResult implements ToXContentObject {
public static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); public static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition) a[2])); a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition.Builder) a[2]));
static { static {
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.STRICT_PARSER.apply(p, null).build(), // TODO change back to STRICT_PARSER once native side is aligned
INFERENCE_MODEL); PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
} }
private final RowResults rowResults; private final RowResults rowResults;
private final Integer progressPercent; private final Integer progressPercent;
private final TrainedModelDefinition.Builder inferenceModelBuilder;
private final TrainedModelDefinition inferenceModel; private final TrainedModelDefinition inferenceModel;
public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition inferenceModel) { public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition.Builder inferenceModelBuilder) {
this.rowResults = rowResults; this.rowResults = rowResults;
this.progressPercent = progressPercent; this.progressPercent = progressPercent;
this.inferenceModel = inferenceModel; this.inferenceModelBuilder = inferenceModelBuilder;
this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
} }
public RowResults getRowResults() { public RowResults getRowResults() {
@ -51,8 +53,8 @@ public class AnalyticsResult implements ToXContentObject {
return progressPercent; return progressPercent;
} }
public TrainedModelDefinition getInferenceModel() { public TrainedModelDefinition.Builder getInferenceModelBuilder() {
return inferenceModel; return inferenceModelBuilder;
} }
@Override @Override

View File

@ -86,6 +86,9 @@ public final class InferenceInternalIndex {
.startObject(TrainedModelConfig.CREATED_BY.getPreferredName()) .startObject(TrainedModelConfig.CREATED_BY.getPreferredName())
.field(TYPE, KEYWORD) .field(TYPE, KEYWORD)
.endObject() .endObject()
.startObject(TrainedModelConfig.INPUT.getPreferredName())
.field(ENABLED, false)
.endObject()
.startObject(TrainedModelConfig.VERSION.getPreferredName()) .startObject(TrainedModelConfig.VERSION.getPreferredName())
.field(TYPE, KEYWORD) .field(TYPE, KEYWORD)
.endObject() .endObject()
@ -95,9 +98,6 @@ public final class InferenceInternalIndex {
.startObject(TrainedModelConfig.CREATE_TIME.getPreferredName()) .startObject(TrainedModelConfig.CREATE_TIME.getPreferredName())
.field(TYPE, DATE) .field(TYPE, DATE)
.endObject() .endObject()
.startObject(TrainedModelConfig.DEFINITION.getPreferredName())
.field(ENABLED, false)
.endObject()
.startObject(TrainedModelConfig.TAGS.getPreferredName()) .startObject(TrainedModelConfig.TAGS.getPreferredName())
.field(TYPE, KEYWORD) .field(TYPE, KEYWORD)
.endObject() .endObject()

View File

@ -8,32 +8,38 @@ package org.elasticsearch.xpack.ml.inference.persistence;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.MultiSearchAction;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.MultiSearchRequestBuilder;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -51,6 +57,8 @@ public class TrainedModelProvider {
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
private final Client client; private final Client client;
private final NamedXContentRegistry xContentRegistry; private final NamedXContentRegistry xContentRegistry;
private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS =
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) { public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) {
this.client = client; this.client = client;
@ -58,76 +66,178 @@ public class TrainedModelProvider {
} }
public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) { public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
XContentBuilder source = trainedModelConfig.toXContent(builder,
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME) if (trainedModelConfig.getDefinition() == null) {
.opType(DocWriteRequest.OpType.CREATE) listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] is required",
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) trainedModelConfig.getModelId(),
.id(trainedModelConfig.getModelId()) TrainedModelConfig.DEFINITION.getPreferredName()));
.source(source); return;
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest,
ActionListener.wrap(
r -> listener.onResponse(true),
e -> {
logger.error(new ParameterizedMessage(
"[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e);
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
} else {
listener.onFailure(
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
RestStatus.INTERNAL_SERVER_ERROR,
e,
trainedModelConfig.getModelId()));
}
}));
} catch (IOException e) {
// not expected to happen but for the sake of completeness
listener.onFailure(new ElasticsearchParseException(
Messages.getMessage(Messages.INFERENCE_FAILED_TO_SERIALIZE_MODEL, trainedModelConfig.getModelId()),
e));
} }
BulkRequest bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig))
.add(createRequest(TrainedModelDefinition.docId(trainedModelConfig.getModelId()), trainedModelConfig.getDefinition()))
.request();
ActionListener<Boolean> wrappedListener = ActionListener.wrap(
listener::onResponse,
e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
listener.onFailure(new ResourceAlreadyExistsException(
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
} else {
listener.onFailure(
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
RestStatus.INTERNAL_SERVER_ERROR,
e,
trainedModelConfig.getModelId()));
}
}
);
ActionListener<BulkResponse> bulkResponseActionListener = ActionListener.wrap(
r -> {
assert r.getItems().length == 2;
if (r.getItems()[0].isFailed()) {
logger.error(new ParameterizedMessage(
"[{}] failed to store trained model config for inference",
trainedModelConfig.getModelId()),
r.getItems()[0].getFailure().getCause());
wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
return;
}
if (r.getItems()[1].isFailed()) {
logger.error(new ParameterizedMessage(
"[{}] failed to store trained model definition for inference",
trainedModelConfig.getModelId()),
r.getItems()[1].getFailure().getCause());
wrappedListener.onFailure(r.getItems()[1].getFailure().getCause());
return;
}
wrappedListener.onResponse(true);
},
wrappedListener::onFailure
);
executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest, bulkResponseActionListener);
} }
public void getTrainedModel(String modelId, ActionListener<TrainedModelConfig> listener) { public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
.idsQuery() .idsQuery()
.addIds(modelId)); .addIds(modelId));
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) MultiSearchRequestBuilder multiSearchRequestBuilder = client.prepareMultiSearch()
.setQuery(queryBuilder) .add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
// use sort to get the last .setQuery(queryBuilder)
.addSort("_index", SortOrder.DESC) // use sort to get the last
.setSize(1) .addSort("_index", SortOrder.DESC)
.request(); .setSize(1)
.request());
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, if (includeDefinition) {
ActionListener.wrap( multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
searchResponse -> { .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
if (searchResponse.getHits().getHits().length == 0) { .idsQuery()
.addIds(TrainedModelDefinition.docId(modelId))))
// use sort to get the last
.addSort("_index", SortOrder.DESC)
.setSize(1)
.request());
}
ActionListener<MultiSearchResponse> multiSearchResponseActionListener = ActionListener.wrap(
multiSearchResponse -> {
TrainedModelConfig.Builder builder;
TrainedModelDefinition definition;
try {
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
} catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return;
} catch (Exception ex) {
listener.onFailure(ex);
return;
}
if (includeDefinition) {
try {
definition = handleSearchItem(multiSearchResponse.getResponses()[1],
modelId,
this::parseModelDefinitionDocLenientlyFromSource);
builder.setDefinition(definition);
} catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException( listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
} catch (Exception ex) {
listener.onFailure(ex);
return; return;
} }
BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef(); }
parseInferenceDocLenientlyFromSource(source, modelId, listener); listener.onResponse(builder.build());
}, },
listener::onFailure)); listener::onFailure
);
executeAsyncWithOrigin(client,
ML_ORIGIN,
MultiSearchAction.INSTANCE,
multiSearchRequestBuilder.request(),
multiSearchResponseActionListener);
} }
private void parseInferenceDocLenientlyFromSource(BytesReference source, private static <T> T handleSearchItem(MultiSearchResponse.Item item,
String modelId, String resourceId,
ActionListener<TrainedModelConfig> modelListener) { CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
if (item.isFailure()) {
throw item.getFailure();
}
if (item.getResponse().getHits().getHits().length == 0) {
throw new ResourceNotFoundException(resourceId);
}
return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId);
}
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception {
try (InputStream stream = source.streamInput(); try (InputStream stream = source.streamInput();
XContentParser parser = XContentFactory.xContent(XContentType.JSON) XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build()); return TrainedModelConfig.fromXContent(parser, true);
} catch (Exception e) { } catch (Exception e) {
logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e);
modelListener.onFailure(e); throw e;
}
}
private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws Exception {
try (InputStream stream = source.streamInput();
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
return TrainedModelDefinition.fromXContent(parser, true).build();
} catch (Exception e) {
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e);
throw e;
}
}
private IndexRequest createRequest(String docId, ToXContentObject body) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
return new IndexRequest()
.opType(DocWriteRequest.OpType.CREATE)
.id(docId)
.source(source);
} catch (IOException ex) {
// This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again
// that is not the users fault. We did something wrong and should throw.
throw ExceptionsHelper.serverError(
new ParameterizedMessage("Unexpected serialization exception for [{}]", docId).getFormattedMessage(),
ex);
} }
} }
} }

View File

@ -126,9 +126,10 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
return null; return null;
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build(); List<String> expectedFieldNames = Arrays.asList("foo", "bar", "baz");
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(JOB_ID);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel))); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(); AnalyticsResultProcessor resultProcessor = createResultProcessor(expectedFieldNames);
resultProcessor.process(process); resultProcessor.process(process);
resultProcessor.awaitForCompletion(); resultProcessor.awaitForCompletion();
@ -142,7 +143,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics")); assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics"));
assertThat(storedModel.getTags(), contains(JOB_ID)); assertThat(storedModel.getTags(), contains(JOB_ID));
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
assertThat(storedModel.getDefinition(), equalTo(inferenceModel)); assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build()));
assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
Map<String, Object> metadata = storedModel.getMetadata(); Map<String, Object> metadata = storedModel.getMetadata();
assertThat(metadata.size(), equalTo(1)); assertThat(metadata.size(), equalTo(1));
assertThat(metadata, hasKey("analytics_config")); assertThat(metadata, hasKey("analytics_config"));
@ -166,7 +168,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
return null; return null;
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build(); TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder("failed_model");
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel))); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(); AnalyticsResultProcessor resultProcessor = createResultProcessor();
@ -192,7 +194,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
} }
private AnalyticsResultProcessor createResultProcessor() { private AnalyticsResultProcessor createResultProcessor() {
return createResultProcessor(Collections.emptyList());
}
private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider, return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider,
auditor); auditor, fieldNames);
} }
} }

View File

@ -14,7 +14,6 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvide
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -33,7 +32,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
protected AnalyticsResult createTestInstance() { protected AnalyticsResult createTestInstance() {
RowResults rowResults = null; RowResults rowResults = null;
Integer progressPercent = null; Integer progressPercent = null;
TrainedModelDefinition inferenceModel = null; TrainedModelDefinition.Builder inferenceModel = null;
if (randomBoolean()) { if (randomBoolean()) {
rowResults = RowResultsTests.createRandom(); rowResults = RowResultsTests.createRandom();
} }
@ -41,13 +40,13 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
progressPercent = randomIntBetween(0, 100); progressPercent = randomIntBetween(0, 100);
} }
if (randomBoolean()) { if (randomBoolean()) {
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build(); inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(null);
} }
return new AnalyticsResult(rowResults, progressPercent, inferenceModel); return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
} }
@Override @Override
protected AnalyticsResult doParseInstance(XContentParser parser) throws IOException { protected AnalyticsResult doParseInstance(XContentParser parser) {
return AnalyticsResult.PARSER.apply(parser, null); return AnalyticsResult.PARSER.apply(parser, null);
} }

View File

@ -6,12 +6,17 @@
package org.elasticsearch.xpack.ml.integration; package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -75,29 +80,75 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
assertThat(exceptionHolder.get(), is(nullValue())); assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder); blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(config)); assertThat(getConfigHolder.get(), equalTo(config));
assertThat(getConfigHolder.get().getDefinition(), is(not(nullValue())));
}
public void testGetTrainedModelConfigWithoutDefinition() throws Exception {
String modelId = "test-get-trained-model-config-no-definition";
TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
TrainedModelConfig config = configBuilder.build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder);
assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(),
equalTo(configBuilder.setCreateTime(config.getCreateTime()).setDefinition((TrainedModelDefinition) null).build()));
assertThat(getConfigHolder.get().getDefinition(), is(nullValue()));
} }
public void testGetMissingTrainingModelConfig() throws Exception { public void testGetMissingTrainingModelConfig() throws Exception {
String modelId = "test-get-missing-trained-model-config"; String modelId = "test-get-missing-trained-model-config";
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>(); AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder); blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
} }
private static TrainedModelConfig buildTrainedModelConfig(String modelId) { public void testGetMissingTrainingModelConfigDefinition() throws Exception {
String modelId = "test-get-missing-trained-model-config-definition";
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
client().delete(new DeleteRequest(InferenceIndexConstants.LATEST_INDEX_NAME)
.id(TrainedModelDefinition.docId(config.getModelId()))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE))
.actionGet();
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
}
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
return TrainedModelConfig.builder() return TrainedModelConfig.builder()
.setCreatedBy("ml_test") .setCreatedBy("ml_test")
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
.setDescription("trained model config for test") .setDescription("trained model config for test")
.setModelId(modelId) .setModelId(modelId)
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.build(); .setInput(TrainedModelInputTests.createRandomInput());
}
private static TrainedModelConfig buildTrainedModelConfig(String modelId) {
return buildTrainedModelConfigBuilder(modelId).build();
} }
@Override @Override