* [ML][Inference] Adding preprocessors to definition object (#47320) * [ML][Inference] Adding preprocessors to definition object * Update TrainedModelConfig.java * adjusting for backport
This commit is contained in:
parent
66116e39ba
commit
f5fe5e7cd6
|
@ -20,7 +20,6 @@ package org.elasticsearch.client.ml.inference;
|
|||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.client.common.TimeUtil;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
|
@ -31,7 +30,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
|||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -64,9 +62,8 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
|
||||
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
|
||||
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
||||
PARSER.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
|
||||
(p, c, n) -> p.namedObject(TrainedModel.class, n, null),
|
||||
(modelDocBuilder) -> { /* Noop does not matter client side */ },
|
||||
PARSER.declareObject(TrainedModelConfig.Builder::setDefinition,
|
||||
(p, c) -> TrainedModelDefinition.fromXContent(p),
|
||||
DEFINITION);
|
||||
}
|
||||
|
||||
|
@ -82,7 +79,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
private final Long modelVersion;
|
||||
private final String modelType;
|
||||
private final Map<String, Object> metadata;
|
||||
private final TrainedModel definition;
|
||||
private final TrainedModelDefinition definition;
|
||||
|
||||
TrainedModelConfig(String modelId,
|
||||
String createdBy,
|
||||
|
@ -91,7 +88,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
Instant createdTime,
|
||||
Long modelVersion,
|
||||
String modelType,
|
||||
TrainedModel definition,
|
||||
TrainedModelDefinition definition,
|
||||
Map<String, Object> metadata) {
|
||||
this.modelId = modelId;
|
||||
this.createdBy = createdBy;
|
||||
|
@ -136,7 +133,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
return metadata;
|
||||
}
|
||||
|
||||
public TrainedModel getDefinition() {
|
||||
public TrainedModelDefinition getDefinition() {
|
||||
return definition;
|
||||
}
|
||||
|
||||
|
@ -169,11 +166,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
builder.field(MODEL_TYPE.getPreferredName(), modelType);
|
||||
}
|
||||
if (definition != null) {
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||
params,
|
||||
false,
|
||||
DEFINITION.getPreferredName(),
|
||||
Collections.singletonList(definition));
|
||||
builder.field(DEFINITION.getPreferredName(), definition);
|
||||
}
|
||||
if (metadata != null) {
|
||||
builder.field(METADATA.getPreferredName(), metadata);
|
||||
|
@ -227,7 +220,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
private Long modelVersion;
|
||||
private String modelType;
|
||||
private Map<String, Object> metadata;
|
||||
private TrainedModel definition;
|
||||
private TrainedModelDefinition.Builder definition;
|
||||
|
||||
public Builder setModelId(String modelId) {
|
||||
this.modelId = modelId;
|
||||
|
@ -273,16 +266,11 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setDefinition(TrainedModel definition) {
|
||||
public Builder setDefinition(TrainedModelDefinition.Builder definition) {
|
||||
this.definition = definition;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Builder setDefinition(List<TrainedModel> definition) {
|
||||
assert definition.size() == 1;
|
||||
return setDefinition(definition.get(0));
|
||||
}
|
||||
|
||||
public TrainedModelConfig build() {
|
||||
return new TrainedModelConfig(
|
||||
modelId,
|
||||
|
@ -292,7 +280,7 @@ public class TrainedModelConfig implements ToXContentObject {
|
|||
createdTime,
|
||||
modelVersion,
|
||||
modelType,
|
||||
definition,
|
||||
definition == null ? null : definition.build(),
|
||||
metadata);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TrainedModelDefinition implements ToXContentObject {
|
||||
|
||||
public static final String NAME = "trained_model_doc";
|
||||
|
||||
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
||||
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
||||
|
||||
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
|
||||
true,
|
||||
TrainedModelDefinition.Builder::new);
|
||||
static {
|
||||
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
|
||||
(p, c, n) -> p.namedObject(TrainedModel.class, n, null),
|
||||
(modelDocBuilder) -> { /* Noop does not matter client side*/ },
|
||||
TRAINED_MODEL);
|
||||
PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
|
||||
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
|
||||
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
|
||||
PREPROCESSORS);
|
||||
}
|
||||
|
||||
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
|
||||
return PARSER.parse(parser, null);
|
||||
}
|
||||
|
||||
private final TrainedModel trainedModel;
|
||||
private final List<PreProcessor> preProcessors;
|
||||
|
||||
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
|
||||
this.trainedModel = trainedModel;
|
||||
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||
params,
|
||||
false,
|
||||
TRAINED_MODEL.getPreferredName(),
|
||||
Collections.singletonList(trainedModel));
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||
params,
|
||||
true,
|
||||
PREPROCESSORS.getPreferredName(),
|
||||
preProcessors);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public TrainedModel getTrainedModel() {
|
||||
return trainedModel;
|
||||
}
|
||||
|
||||
public List<PreProcessor> getPreProcessors() {
|
||||
return preProcessors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
||||
return Objects.equals(trainedModel, that.trainedModel) &&
|
||||
Objects.equals(preProcessors, that.preProcessors) ;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(trainedModel, preProcessors);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private List<PreProcessor> preProcessors;
|
||||
private TrainedModel trainedModel;
|
||||
|
||||
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
|
||||
this.preProcessors = preProcessors;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setTrainedModel(TrainedModel trainedModel) {
|
||||
this.trainedModel = trainedModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
|
||||
assert trainedModel.size() == 1;
|
||||
return setTrainedModel(trainedModel.get(0));
|
||||
}
|
||||
|
||||
public TrainedModelDefinition build() {
|
||||
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -18,13 +18,13 @@
|
|||
*/
|
||||
package org.elasticsearch.client.ml.inference.preprocessing;
|
||||
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.client.ml.inference.NamedXContentObject;
|
||||
|
||||
|
||||
/**
|
||||
* Describes a pre-processor for a defined machine learning model
|
||||
*/
|
||||
public interface PreProcessor extends ToXContentObject {
|
||||
public interface PreProcessor extends NamedXContentObject {
|
||||
|
||||
/**
|
||||
* @return The name of the pre-processor
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -61,7 +60,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
|
|||
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||
randomBoolean() ? null : randomNonNegativeLong(),
|
||||
randomAlphaOfLength(10),
|
||||
randomFrom(TreeTests.createRandom()),
|
||||
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.inference;
|
||||
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class TrainedModelDefinitionTests extends AbstractXContentTestCase<TrainedModelDefinition> {
|
||||
|
||||
@Override
|
||||
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
|
||||
return TrainedModelDefinition.fromXContent(parser).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
public static TrainedModelDefinition.Builder createRandomBuilder() {
|
||||
int numberOfProcessors = randomIntBetween(1, 10);
|
||||
return new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||
OneHotEncodingTests.createRandom(),
|
||||
TargetMeanEncodingTests.createRandom()))
|
||||
.limit(numberOfProcessors)
|
||||
.collect(Collectors.toList()))
|
||||
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TrainedModelDefinition createTestInstance() {
|
||||
return createRandomBuilder().build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
}
|
|
@ -17,18 +17,13 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.common.time.TimeUtils;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -65,11 +60,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
|
||||
parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
|
||||
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
|
||||
parser.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
|
||||
(p, c, n) -> ignoreUnknownFields ?
|
||||
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
|
||||
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
|
||||
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
|
||||
parser.declareObject(TrainedModelConfig.Builder::setDefinition,
|
||||
(p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
|
||||
DEFINITION);
|
||||
return parser;
|
||||
}
|
||||
|
@ -94,7 +86,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
// TODO how to reference and store large models that will not be executed in Java???
|
||||
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
|
||||
// TODO Should this be lazily parsed when loading via the index???
|
||||
private final TrainedModel definition;
|
||||
private final TrainedModelDefinition definition;
|
||||
TrainedModelConfig(String modelId,
|
||||
String createdBy,
|
||||
Version version,
|
||||
|
@ -102,7 +94,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
Instant createdTime,
|
||||
Long modelVersion,
|
||||
String modelType,
|
||||
TrainedModel definition,
|
||||
TrainedModelDefinition definition,
|
||||
Map<String, Object> metadata) {
|
||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
|
||||
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
|
||||
|
@ -123,7 +115,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
createdTime = in.readInstant();
|
||||
modelVersion = in.readVLong();
|
||||
modelType = in.readString();
|
||||
definition = in.readOptionalNamedWriteable(TrainedModel.class);
|
||||
definition = in.readOptionalWriteable(TrainedModelDefinition::new);
|
||||
metadata = in.readMap();
|
||||
}
|
||||
|
||||
|
@ -160,7 +152,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
}
|
||||
|
||||
@Nullable
|
||||
public TrainedModel getDefinition() {
|
||||
public TrainedModelDefinition getDefinition() {
|
||||
return definition;
|
||||
}
|
||||
|
||||
|
@ -177,7 +169,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
out.writeInstant(createdTime);
|
||||
out.writeVLong(modelVersion);
|
||||
out.writeString(modelType);
|
||||
out.writeOptionalNamedWriteable(definition);
|
||||
out.writeOptionalWriteable(definition);
|
||||
out.writeMap(metadata);
|
||||
}
|
||||
|
||||
|
@ -194,11 +186,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
|
||||
builder.field(MODEL_TYPE.getPreferredName(), modelType);
|
||||
if (definition != null) {
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||
params,
|
||||
false,
|
||||
DEFINITION.getPreferredName(),
|
||||
Collections.singletonList(definition));
|
||||
builder.field(DEFINITION.getPreferredName(), definition);
|
||||
}
|
||||
if (metadata != null) {
|
||||
builder.field(METADATA.getPreferredName(), metadata);
|
||||
|
@ -241,7 +229,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
modelVersion);
|
||||
}
|
||||
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String modelId;
|
||||
|
@ -252,7 +239,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
private Long modelVersion;
|
||||
private String modelType;
|
||||
private Map<String, Object> metadata;
|
||||
private TrainedModel definition;
|
||||
private TrainedModelDefinition.Builder definition;
|
||||
|
||||
public Builder setModelId(String modelId) {
|
||||
this.modelId = modelId;
|
||||
|
@ -298,19 +285,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
return this;
|
||||
}
|
||||
|
||||
public Builder setDefinition(TrainedModel definition) {
|
||||
public Builder setDefinition(TrainedModelDefinition.Builder definition) {
|
||||
this.definition = definition;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Builder setDefinition(List<TrainedModel> definition) {
|
||||
if (definition.size() != 1) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
|
||||
DEFINITION.getPreferredName());
|
||||
}
|
||||
return setDefinition(definition.get(0));
|
||||
}
|
||||
|
||||
// TODO move to REST level instead of here in the builder
|
||||
public void validate() {
|
||||
// We require a definition to be available until we support other means of supplying the definition
|
||||
|
@ -352,7 +331,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
createdTime,
|
||||
modelVersion,
|
||||
modelType,
|
||||
definition,
|
||||
definition == null ? null : definition.build(),
|
||||
metadata);
|
||||
}
|
||||
|
||||
|
@ -365,7 +344,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
Instant.now(),
|
||||
modelVersion,
|
||||
modelType,
|
||||
definition,
|
||||
definition == null ? null : definition.build(),
|
||||
metadata);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class TrainedModelDefinition implements ToXContentObject, Writeable {
|
||||
|
||||
public static final String NAME = "trained_model_doc";
|
||||
|
||||
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
|
||||
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
|
||||
|
||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
|
||||
public static final ObjectParser<TrainedModelDefinition.Builder, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) {
|
||||
ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
|
||||
ignoreUnknownFields,
|
||||
TrainedModelDefinition.Builder::new);
|
||||
parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
|
||||
(p, c, n) -> ignoreUnknownFields ?
|
||||
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
|
||||
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
|
||||
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
|
||||
TRAINED_MODEL);
|
||||
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
|
||||
(p, c, n) -> ignoreUnknownFields ?
|
||||
p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
|
||||
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
|
||||
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
|
||||
PREPROCESSORS);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException {
|
||||
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
|
||||
}
|
||||
|
||||
private final TrainedModel trainedModel;
|
||||
private final List<PreProcessor> preProcessors;
|
||||
|
||||
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
|
||||
this.trainedModel = trainedModel;
|
||||
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
|
||||
}
|
||||
|
||||
public TrainedModelDefinition(StreamInput in) throws IOException {
|
||||
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
|
||||
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeNamedWriteable(trainedModel);
|
||||
out.writeNamedWriteableList(preProcessors);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||
params,
|
||||
false,
|
||||
TRAINED_MODEL.getPreferredName(),
|
||||
Collections.singletonList(trainedModel));
|
||||
NamedXContentObjectHelper.writeNamedObjects(builder,
|
||||
params,
|
||||
true,
|
||||
PREPROCESSORS.getPreferredName(),
|
||||
preProcessors);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public TrainedModel getTrainedModel() {
|
||||
return trainedModel;
|
||||
}
|
||||
|
||||
public List<PreProcessor> getPreProcessors() {
|
||||
return preProcessors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
TrainedModelDefinition that = (TrainedModelDefinition) o;
|
||||
return Objects.equals(trainedModel, that.trainedModel) &&
|
||||
Objects.equals(preProcessors, that.preProcessors) ;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(trainedModel, preProcessors);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private List<PreProcessor> preProcessors;
|
||||
private TrainedModel trainedModel;
|
||||
private boolean processorsInOrder;
|
||||
|
||||
private static Builder builderForParser() {
|
||||
return new Builder(false);
|
||||
}
|
||||
|
||||
private Builder(boolean processorsInOrder) {
|
||||
this.processorsInOrder = processorsInOrder;
|
||||
}
|
||||
|
||||
public Builder() {
|
||||
this(true);
|
||||
}
|
||||
|
||||
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
|
||||
this.preProcessors = preProcessors;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder setTrainedModel(TrainedModel trainedModel) {
|
||||
this.trainedModel = trainedModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
|
||||
if (trainedModel.size() != 1) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
|
||||
TRAINED_MODEL.getPreferredName());
|
||||
}
|
||||
return setTrainedModel(trainedModel.get(0));
|
||||
}
|
||||
|
||||
private void setProcessorsInOrder(boolean value) {
|
||||
this.processorsInOrder = value;
|
||||
}
|
||||
|
||||
public TrainedModelDefinition build() {
|
||||
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
|
||||
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
|
||||
}
|
||||
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
||||
import org.junit.Before;
|
||||
|
@ -65,7 +64,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
Instant.ofEpochMilli(randomNonNegativeLong()),
|
||||
randomBoolean() ? null : randomNonNegativeLong(),
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : randomFrom(TreeTests.createRandom()),
|
||||
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
|
||||
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
|
||||
}
|
||||
|
||||
|
@ -97,14 +96,18 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
public void testValidateWithInvalidID() {
|
||||
String modelId = "InvalidID-";
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
|
||||
}
|
||||
|
||||
public void testValidateWithLongID() {
|
||||
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
|
||||
}
|
||||
|
||||
|
@ -112,21 +115,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
|||
String modelId = "simplemodel";
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setDefinition(randomFrom(TreeTests.createRandom()))
|
||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setCreatedTime(Instant.now())
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
|
||||
|
||||
ex = expectThrows(ElasticsearchException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setDefinition(randomFrom(TreeTests.createRandom()))
|
||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setVersion(Version.CURRENT)
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
|
||||
|
||||
ex = expectThrows(ElasticsearchException.class,
|
||||
() -> TrainedModelConfig.builder()
|
||||
.setDefinition(randomFrom(TreeTests.createRandom()))
|
||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setCreatedBy("ml_user")
|
||||
.setModelId(modelId).validate());
|
||||
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.inference;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
|
||||
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
|
||||
|
||||
private boolean lenient;
|
||||
|
||||
@Before
|
||||
public void chooseStrictOrLenient() {
|
||||
lenient = randomBoolean();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
|
||||
return TrainedModelDefinition.fromXContent(parser, lenient).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return lenient;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
public static TrainedModelDefinition.Builder createRandomBuilder() {
|
||||
int numberOfProcessors = randomIntBetween(1, 10);
|
||||
return new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(
|
||||
randomBoolean() ? null :
|
||||
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||
OneHotEncodingTests.createRandom(),
|
||||
TargetMeanEncodingTests.createRandom()))
|
||||
.limit(numberOfProcessors)
|
||||
.collect(Collectors.toList()))
|
||||
.setTrainedModel(randomFrom(TreeTests.createRandom()));
|
||||
}
|
||||
@Override
|
||||
protected TrainedModelDefinition createTestInstance() {
|
||||
return createRandomBuilder().build();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<TrainedModelDefinition> instanceReader() {
|
||||
return TrainedModelDefinition::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
|
||||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
||||
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(entries);
|
||||
}
|
||||
|
||||
}
|
|
@ -11,7 +11,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|||
import org.elasticsearch.search.SearchModule;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
|
@ -93,7 +93,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
|||
private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
|
||||
return TrainedModelConfig.builder()
|
||||
.setCreatedBy("ml_test")
|
||||
.setDefinition(TreeTests.createRandom())
|
||||
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
|
||||
.setDescription("trained model config for test")
|
||||
.setModelId(modelId)
|
||||
.setModelType("binary_decision_tree")
|
||||
|
|
Loading…
Reference in New Issue