* [ML][Inference] Adding preprocessors to definition object (#47320) * [ML][Inference] Adding preprocessors to definition object * Update TrainedModelConfig.java * adjusting for backport
This commit is contained in:
parent
66116e39ba
commit
f5fe5e7cd6
|
@ -20,7 +20,6 @@ package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
import org.elasticsearch.Version;
|
import org.elasticsearch.Version;
|
||||||
import org.elasticsearch.client.common.TimeUtil;
|
import org.elasticsearch.client.common.TimeUtil;
|
||||||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
|
@ -31,7 +30,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@ -64,9 +62,8 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
|
PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
|
||||||
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
|
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
|
||||||
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
||||||
PARSER.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
|
PARSER.declareObject(TrainedModelConfig.Builder::setDefinition,
|
||||||
(p, c, n) -> p.namedObject(TrainedModel.class, n, null),
|
(p, c) -> TrainedModelDefinition.fromXContent(p),
|
||||||
(modelDocBuilder) -> { /* Noop does not matter client side */ },
|
|
||||||
DEFINITION);
|
DEFINITION);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,7 +79,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
private final Long modelVersion;
|
private final Long modelVersion;
|
||||||
private final String modelType;
|
private final String modelType;
|
||||||
private final Map<String, Object> metadata;
|
private final Map<String, Object> metadata;
|
||||||
private final TrainedModel definition;
|
private final TrainedModelDefinition definition;
|
||||||
|
|
||||||
TrainedModelConfig(String modelId,
|
TrainedModelConfig(String modelId,
|
||||||
String createdBy,
|
String createdBy,
|
||||||
|
@ -91,7 +88,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
Instant createdTime,
|
Instant createdTime,
|
||||||
Long modelVersion,
|
Long modelVersion,
|
||||||
String modelType,
|
String modelType,
|
||||||
TrainedModel definition,
|
TrainedModelDefinition definition,
|
||||||
Map<String, Object> metadata) {
|
Map<String, Object> metadata) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
this.createdBy = createdBy;
|
this.createdBy = createdBy;
|
||||||
|
@ -136,7 +133,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainedModel getDefinition() {
|
public TrainedModelDefinition getDefinition() {
|
||||||
return definition;
|
return definition;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,11 +166,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
builder.field(MODEL_TYPE.getPreferredName(), modelType);
|
builder.field(MODEL_TYPE.getPreferredName(), modelType);
|
||||||
}
|
}
|
||||||
if (definition != null) {
|
if (definition != null) {
|
||||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
builder.field(DEFINITION.getPreferredName(), definition);
|
||||||
params,
|
|
||||||
false,
|
|
||||||
DEFINITION.getPreferredName(),
|
|
||||||
Collections.singletonList(definition));
|
|
||||||
}
|
}
|
||||||
if (metadata != null) {
|
if (metadata != null) {
|
||||||
builder.field(METADATA.getPreferredName(), metadata);
|
builder.field(METADATA.getPreferredName(), metadata);
|
||||||
|
@ -227,7 +220,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
private Long modelVersion;
|
private Long modelVersion;
|
||||||
private String modelType;
|
private String modelType;
|
||||||
private Map<String, Object> metadata;
|
private Map<String, Object> metadata;
|
||||||
private TrainedModel definition;
|
private TrainedModelDefinition.Builder definition;
|
||||||
|
|
||||||
public Builder setModelId(String modelId) {
|
public Builder setModelId(String modelId) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
|
@ -273,16 +266,11 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder setDefinition(TrainedModel definition) {
|
public Builder setDefinition(TrainedModelDefinition.Builder definition) {
|
||||||
this.definition = definition;
|
this.definition = definition;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Builder setDefinition(List<TrainedModel> definition) {
|
|
||||||
assert definition.size() == 1;
|
|
||||||
return setDefinition(definition.get(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
public TrainedModelConfig build() {
|
public TrainedModelConfig build() {
|
||||||
return new TrainedModelConfig(
|
return new TrainedModelConfig(
|
||||||
modelId,
|
modelId,
|
||||||
|
@ -292,7 +280,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
||||||
createdTime,
|
createdTime,
|
||||||
modelVersion,
|
modelVersion,
|
||||||
modelType,
|
modelType,
|
||||||
definition,
|
definition == null ? null : definition.build(),
|
||||||
metadata);
|
metadata);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
|
||||||
|
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class TrainedModelDefinition implements ToXContentObject {
|
||||||
|
|
||||||
|
public static final String NAME = "trained_model_doc";
|
||||||
|
|
||||||
|
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
||||||
|
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
||||||
|
|
||||||
|
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
|
||||||
|
true,
|
||||||
|
TrainedModelDefinition.Builder::new);
|
||||||
|
static {
|
||||||
|
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
|
||||||
|
(p, c, n) -> p.namedObject(TrainedModel.class, n, null),
|
||||||
|
(modelDocBuilder) -> { /* Noop does not matter client side*/ },
|
||||||
|
TRAINED_MODEL);
|
||||||
|
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
|
||||||
|
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
|
||||||
|
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
|
||||||
|
PREPROCESSORS);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
|
||||||
|
return PARSER.parse(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final TrainedModel trainedModel;
|
||||||
|
private final List<PreProcessor> preProcessors;
|
||||||
|
|
||||||
|
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
|
||||||
|
this.trainedModel = trainedModel;
|
||||||
|
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||||
|
params,
|
||||||
|
false,
|
||||||
|
TRAINED_MODEL.getPreferredName(),
|
||||||
|
Collections.singletonList(trainedModel));
|
||||||
|
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||||
|
params,
|
||||||
|
true,
|
||||||
|
PREPROCESSORS.getPreferredName(),
|
||||||
|
preProcessors);
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainedModel getTrainedModel() {
|
||||||
|
return trainedModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<PreProcessor> getPreProcessors() {
|
||||||
|
return preProcessors;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Strings.toString(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
||||||
|
return Objects.equals(trainedModel, that.trainedModel) &&
|
||||||
|
Objects.equals(preProcessors, that.preProcessors) ;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(trainedModel, preProcessors);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
private List<PreProcessor> preProcessors;
|
||||||
|
private TrainedModel trainedModel;
|
||||||
|
|
||||||
|
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
|
||||||
|
this.preProcessors = preProcessors;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setTrainedModel(TrainedModel trainedModel) {
|
||||||
|
this.trainedModel = trainedModel;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
|
||||||
|
assert trainedModel.size() == 1;
|
||||||
|
return setTrainedModel(trainedModel.get(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainedModelDefinition build() {
|
||||||
|
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -18,13 +18,13 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||||
|
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
import org.elasticsearch.client.ml.inference.NamedXContentObject;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Describes a pre-processor for a defined machine learning model
|
* Describes a pre-processor for a defined machine learning model
|
||||||
*/
|
*/
|
||||||
public interface PreProcessor extends ToXContentObject {
|
public interface PreProcessor extends NamedXContentObject {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return The name of the pre-processor
|
* @return The name of the pre-processor
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
package org.elasticsearch.client.ml.inference;
|
package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
import org.elasticsearch.Version;
|
import org.elasticsearch.Version;
|
||||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
|
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
@ -61,7 +60,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
|
||||||
Instant.ofEpochMilli(randomNonNegativeLong()),
|
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||||
randomBoolean() ? null : randomNonNegativeLong(),
|
randomBoolean() ? null : randomNonNegativeLong(),
|
||||||
randomAlphaOfLength(10),
|
randomAlphaOfLength(10),
|
||||||
randomFrom(TreeTests.createRandom()),
|
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
||||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.search.SearchModule;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
|
public class TrainedModelDefinitionTests extends AbstractXContentTestCase<TrainedModelDefinition> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return TrainedModelDefinition.fromXContent(parser).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||||
|
return field -> !field.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelDefinition.Builder createRandomBuilder() {
|
||||||
|
int numberOfProcessors = randomIntBetween(1, 10);
|
||||||
|
return new TrainedModelDefinition.Builder()
|
||||||
|
.setPreProcessors(
|
||||||
|
randomBoolean() ? null :
|
||||||
|
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||||
|
OneHotEncodingTests.createRandom(),
|
||||||
|
TargetMeanEncodingTests.createRandom()))
|
||||||
|
.limit(numberOfProcessors)
|
||||||
|
.collect(Collectors.toList()))
|
||||||
|
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelDefinition createTestInstance() {
|
||||||
|
return createRandomBuilder().build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||||
|
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||||
|
return new NamedXContentRegistry(namedXContent);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -17,18 +17,13 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.xpack.core.common.time.TimeUtils;
|
import org.elasticsearch.xpack.core.common.time.TimeUtils;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@ -65,11 +60,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
|
parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
|
||||||
parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
|
parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
|
||||||
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
||||||
parser.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
|
parser.declareObject(TrainedModelConfig.Builder::setDefinition,
|
||||||
(p, c, n) -> ignoreUnknownFields ?
|
(p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
|
||||||
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
|
|
||||||
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
|
|
||||||
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
|
|
||||||
DEFINITION);
|
DEFINITION);
|
||||||
return parser;
|
return parser;
|
||||||
}
|
}
|
||||||
|
@ -94,7 +86,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
// TODO how to reference and store large models that will not be executed in Java???
|
// TODO how to reference and store large models that will not be executed in Java???
|
||||||
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
|
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
|
||||||
// TODO Should this be lazily parsed when loading via the index???
|
// TODO Should this be lazily parsed when loading via the index???
|
||||||
private final TrainedModel definition;
|
private final TrainedModelDefinition definition;
|
||||||
TrainedModelConfig(String modelId,
|
TrainedModelConfig(String modelId,
|
||||||
String createdBy,
|
String createdBy,
|
||||||
Version version,
|
Version version,
|
||||||
|
@ -102,7 +94,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
Instant createdTime,
|
Instant createdTime,
|
||||||
Long modelVersion,
|
Long modelVersion,
|
||||||
String modelType,
|
String modelType,
|
||||||
TrainedModel definition,
|
TrainedModelDefinition definition,
|
||||||
Map<String, Object> metadata) {
|
Map<String, Object> metadata) {
|
||||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||||
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
|
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
|
||||||
|
@ -123,7 +115,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
createdTime = in.readInstant();
|
createdTime = in.readInstant();
|
||||||
modelVersion = in.readVLong();
|
modelVersion = in.readVLong();
|
||||||
modelType = in.readString();
|
modelType = in.readString();
|
||||||
definition = in.readOptionalNamedWriteable(TrainedModel.class);
|
definition = in.readOptionalWriteable(TrainedModelDefinition::new);
|
||||||
metadata = in.readMap();
|
metadata = in.readMap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,7 +152,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
public TrainedModel getDefinition() {
|
public TrainedModelDefinition getDefinition() {
|
||||||
return definition;
|
return definition;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,7 +169,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
out.writeInstant(createdTime);
|
out.writeInstant(createdTime);
|
||||||
out.writeVLong(modelVersion);
|
out.writeVLong(modelVersion);
|
||||||
out.writeString(modelType);
|
out.writeString(modelType);
|
||||||
out.writeOptionalNamedWriteable(definition);
|
out.writeOptionalWriteable(definition);
|
||||||
out.writeMap(metadata);
|
out.writeMap(metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,11 +186,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
|
builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
|
||||||
builder.field(MODEL_TYPE.getPreferredName(), modelType);
|
builder.field(MODEL_TYPE.getPreferredName(), modelType);
|
||||||
if (definition != null) {
|
if (definition != null) {
|
||||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
builder.field(DEFINITION.getPreferredName(), definition);
|
||||||
params,
|
|
||||||
false,
|
|
||||||
DEFINITION.getPreferredName(),
|
|
||||||
Collections.singletonList(definition));
|
|
||||||
}
|
}
|
||||||
if (metadata != null) {
|
if (metadata != null) {
|
||||||
builder.field(METADATA.getPreferredName(), metadata);
|
builder.field(METADATA.getPreferredName(), metadata);
|
||||||
|
@ -241,7 +229,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
modelVersion);
|
modelVersion);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
|
|
||||||
private String modelId;
|
private String modelId;
|
||||||
|
@ -252,7 +239,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
private Long modelVersion;
|
private Long modelVersion;
|
||||||
private String modelType;
|
private String modelType;
|
||||||
private Map<String, Object> metadata;
|
private Map<String, Object> metadata;
|
||||||
private TrainedModel definition;
|
private TrainedModelDefinition.Builder definition;
|
||||||
|
|
||||||
public Builder setModelId(String modelId) {
|
public Builder setModelId(String modelId) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
|
@ -298,19 +285,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder setDefinition(TrainedModel definition) {
|
public Builder setDefinition(TrainedModelDefinition.Builder definition) {
|
||||||
this.definition = definition;
|
this.definition = definition;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Builder setDefinition(List<TrainedModel> definition) {
|
|
||||||
if (definition.size() != 1) {
|
|
||||||
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
|
|
||||||
DEFINITION.getPreferredName());
|
|
||||||
}
|
|
||||||
return setDefinition(definition.get(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO move to REST level instead of here in the builder
|
// TODO move to REST level instead of here in the builder
|
||||||
public void validate() {
|
public void validate() {
|
||||||
// We require a definition to be available until we support other means of supplying the definition
|
// We require a definition to be available until we support other means of supplying the definition
|
||||||
|
@ -352,7 +331,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
createdTime,
|
createdTime,
|
||||||
modelVersion,
|
modelVersion,
|
||||||
modelType,
|
modelType,
|
||||||
definition,
|
definition == null ? null : definition.build(),
|
||||||
metadata);
|
metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,7 +344,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
Instant.now(),
|
Instant.now(),
|
||||||
modelVersion,
|
modelVersion,
|
||||||
modelType,
|
modelType,
|
||||||
definition,
|
definition == null ? null : definition.build(),
|
||||||
metadata);
|
metadata);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,176 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
|
public static final String NAME = "trained_model_doc";
|
||||||
|
|
||||||
|
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
||||||
|
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
||||||
|
|
||||||
|
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||||
|
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
|
||||||
|
public static final ObjectParser<TrainedModelDefinition.Builder, Void> STRICT_PARSER = createParser(false);
|
||||||
|
|
||||||
|
private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) {
|
||||||
|
ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
|
||||||
|
ignoreUnknownFields,
|
||||||
|
TrainedModelDefinition.Builder::new);
|
||||||
|
parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
|
||||||
|
(p, c, n) -> ignoreUnknownFields ?
|
||||||
|
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
|
||||||
|
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
|
||||||
|
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
|
||||||
|
TRAINED_MODEL);
|
||||||
|
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
|
||||||
|
(p, c, n) -> ignoreUnknownFields ?
|
||||||
|
p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
|
||||||
|
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
|
||||||
|
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
|
||||||
|
PREPROCESSORS);
|
||||||
|
return parser;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException {
|
||||||
|
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final TrainedModel trainedModel;
|
||||||
|
/** The preprocessors held by this definition. */
private final List<PreProcessor> preProcessors;
|
||||||
|
|
||||||
|
/**
 * Creates a definition from its parts.
 *
 * @param trainedModel  the trained model; stored as given
 * @param preProcessors the preprocessors; {@code null} is stored as an empty list, anything else is
 *                      wrapped in an unmodifiable view (the list is not copied)
 */
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
    this.trainedModel = trainedModel;
    if (preProcessors == null) {
        this.preProcessors = Collections.emptyList();
    } else {
        this.preProcessors = Collections.unmodifiableList(preProcessors);
    }
}
|
||||||
|
|
||||||
|
public TrainedModelDefinition(StreamInput in) throws IOException {
|
||||||
|
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
|
||||||
|
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeNamedWriteable(trainedModel);
|
||||||
|
out.writeNamedWriteableList(preProcessors);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||||
|
params,
|
||||||
|
false,
|
||||||
|
TRAINED_MODEL.getPreferredName(),
|
||||||
|
Collections.singletonList(trainedModel));
|
||||||
|
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||||
|
params,
|
||||||
|
true,
|
||||||
|
PREPROCESSORS.getPreferredName(),
|
||||||
|
preProcessors);
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainedModel getTrainedModel() {
|
||||||
|
return trainedModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<PreProcessor> getPreProcessors() {
|
||||||
|
return preProcessors;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Strings.toString(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
||||||
|
return Objects.equals(trainedModel, that.trainedModel) &&
|
||||||
|
Objects.equals(preProcessors, that.preProcessors) ;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(trainedModel, preProcessors);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
|
||||||
|
private List<PreProcessor> preProcessors;
|
||||||
|
private TrainedModel trainedModel;
|
||||||
|
private boolean processorsInOrder;
|
||||||
|
|
||||||
|
private static Builder builderForParser() {
|
||||||
|
return new Builder(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Builder(boolean processorsInOrder) {
|
||||||
|
this.processorsInOrder = processorsInOrder;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder() {
|
||||||
|
this(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
|
||||||
|
this.preProcessors = preProcessors;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setTrainedModel(TrainedModel trainedModel) {
|
||||||
|
this.trainedModel = trainedModel;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
|
||||||
|
if (trainedModel.size() != 1) {
|
||||||
|
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
|
||||||
|
TRAINED_MODEL.getPreferredName());
|
||||||
|
}
|
||||||
|
return setTrainedModel(trainedModel.get(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setProcessorsInOrder(boolean value) {
|
||||||
|
this.processorsInOrder = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainedModelDefinition build() {
|
||||||
|
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
|
||||||
|
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
|
||||||
|
}
|
||||||
|
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.SearchModule;
|
import org.elasticsearch.search.SearchModule;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -65,7 +64,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
Instant.ofEpochMilli(randomNonNegativeLong()),
|
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||||
randomBoolean() ? null : randomNonNegativeLong(),
|
randomBoolean() ? null : randomNonNegativeLong(),
|
||||||
randomAlphaOfLength(10),
|
randomAlphaOfLength(10),
|
||||||
randomBoolean() ? null : randomFrom(TreeTests.createRandom()),
|
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
||||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,14 +96,18 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
public void testValidateWithInvalidID() {
|
public void testValidateWithInvalidID() {
|
||||||
String modelId = "InvalidID-";
|
String modelId = "InvalidID-";
|
||||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
|
() -> TrainedModelConfig.builder()
|
||||||
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||||
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
|
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testValidateWithLongID() {
|
public void testValidateWithLongID() {
|
||||||
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
|
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
|
||||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
|
() -> TrainedModelConfig.builder()
|
||||||
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||||
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
|
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,21 +115,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
String modelId = "simplemodel";
|
String modelId = "simplemodel";
|
||||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(randomFrom(TreeTests.createRandom()))
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||||
.setCreatedTime(Instant.now())
|
.setCreatedTime(Instant.now())
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
|
assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
|
||||||
|
|
||||||
ex = expectThrows(ElasticsearchException.class,
|
ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(randomFrom(TreeTests.createRandom()))
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||||
.setVersion(Version.CURRENT)
|
.setVersion(Version.CURRENT)
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
|
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
|
||||||
|
|
||||||
ex = expectThrows(ElasticsearchException.class,
|
ex = expectThrows(ElasticsearchException.class,
|
||||||
() -> TrainedModelConfig.builder()
|
() -> TrainedModelConfig.builder()
|
||||||
.setDefinition(randomFrom(TreeTests.createRandom()))
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||||
.setCreatedBy("ml_user")
|
.setCreatedBy("ml_user")
|
||||||
.setModelId(modelId).validate());
|
.setModelId(modelId).validate());
|
||||||
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
|
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.search.SearchModule;
|
||||||
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
|
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
|
||||||
|
|
||||||
|
private boolean lenient;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void chooseStrictOrLenient() {
|
||||||
|
lenient = randomBoolean();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return TrainedModelDefinition.fromXContent(parser, lenient).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return lenient;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||||
|
return field -> !field.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelDefinition.Builder createRandomBuilder() {
|
||||||
|
int numberOfProcessors = randomIntBetween(1, 10);
|
||||||
|
return new TrainedModelDefinition.Builder()
|
||||||
|
.setPreProcessors(
|
||||||
|
randomBoolean() ? null :
|
||||||
|
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||||
|
OneHotEncodingTests.createRandom(),
|
||||||
|
TargetMeanEncodingTests.createRandom()))
|
||||||
|
.limit(numberOfProcessors)
|
||||||
|
.collect(Collectors.toList()))
|
||||||
|
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
protected TrainedModelDefinition createTestInstance() {
|
||||||
|
return createRandomBuilder().build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<TrainedModelDefinition> instanceReader() {
|
||||||
|
return TrainedModelDefinition::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||||
|
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||||
|
return new NamedXContentRegistry(namedXContent);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||||
|
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||||
|
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||||
|
return new NamedWriteableRegistry(entries);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -11,7 +11,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.search.SearchModule;
|
import org.elasticsearch.search.SearchModule;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
||||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
|
@ -93,7 +93,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
|
private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
|
||||||
return TrainedModelConfig.builder()
|
return TrainedModelConfig.builder()
|
||||||
.setCreatedBy("ml_test")
|
.setCreatedBy("ml_test")
|
||||||
.setDefinition(TreeTests.createRandom())
|
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||||
.setDescription("trained model config for test")
|
.setDescription("trained model config for test")
|
||||||
.setModelId(modelId)
|
.setModelId(modelId)
|
||||||
.setModelType("binary_decision_tree")
|
.setModelType("binary_decision_tree")
|
||||||
|
|
Loading…
Reference in New Issue