[7.x] [ML][Inference] Adding preprocessors to definition object (#47320) (#47370)

* [ML][Inference] Adding preprocessors to definition object (#47320)

* [ML][Inference] Adding preprocessors to definition object

* Update TrainedModelConfig.java

* adjusting for backport
This commit is contained in:
Benjamin Trent 2019-10-01 13:31:25 -04:00 committed by GitHub
parent 66116e39ba
commit f5fe5e7cd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 523 additions and 67 deletions

View File

@ -20,7 +20,6 @@ package org.elasticsearch.client.ml.inference;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
@ -31,7 +30,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
import java.time.Instant; import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -64,9 +62,8 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareNamedObjects(TrainedModelConfig.Builder::setDefinition, PARSER.declareObject(TrainedModelConfig.Builder::setDefinition,
(p, c, n) -> p.namedObject(TrainedModel.class, n, null), (p, c) -> TrainedModelDefinition.fromXContent(p),
(modelDocBuilder) -> { /* Noop does not matter client side */ },
DEFINITION); DEFINITION);
} }
@ -82,7 +79,7 @@ public class TrainedModelConfig implements ToXContentObject {
private final Long modelVersion; private final Long modelVersion;
private final String modelType; private final String modelType;
private final Map<String, Object> metadata; private final Map<String, Object> metadata;
private final TrainedModel definition; private final TrainedModelDefinition definition;
TrainedModelConfig(String modelId, TrainedModelConfig(String modelId,
String createdBy, String createdBy,
@ -91,7 +88,7 @@ public class TrainedModelConfig implements ToXContentObject {
Instant createdTime, Instant createdTime,
Long modelVersion, Long modelVersion,
String modelType, String modelType,
TrainedModel definition, TrainedModelDefinition definition,
Map<String, Object> metadata) { Map<String, Object> metadata) {
this.modelId = modelId; this.modelId = modelId;
this.createdBy = createdBy; this.createdBy = createdBy;
@ -136,7 +133,7 @@ public class TrainedModelConfig implements ToXContentObject {
return metadata; return metadata;
} }
public TrainedModel getDefinition() { public TrainedModelDefinition getDefinition() {
return definition; return definition;
} }
@ -169,11 +166,7 @@ public class TrainedModelConfig implements ToXContentObject {
builder.field(MODEL_TYPE.getPreferredName(), modelType); builder.field(MODEL_TYPE.getPreferredName(), modelType);
} }
if (definition != null) { if (definition != null) {
NamedXContentObjectHelper.writeNamedObjects(builder, builder.field(DEFINITION.getPreferredName(), definition);
params,
false,
DEFINITION.getPreferredName(),
Collections.singletonList(definition));
} }
if (metadata != null) { if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata); builder.field(METADATA.getPreferredName(), metadata);
@ -227,7 +220,7 @@ public class TrainedModelConfig implements ToXContentObject {
private Long modelVersion; private Long modelVersion;
private String modelType; private String modelType;
private Map<String, Object> metadata; private Map<String, Object> metadata;
private TrainedModel definition; private TrainedModelDefinition.Builder definition;
public Builder setModelId(String modelId) { public Builder setModelId(String modelId) {
this.modelId = modelId; this.modelId = modelId;
@ -273,16 +266,11 @@ public class TrainedModelConfig implements ToXContentObject {
return this; return this;
} }
public Builder setDefinition(TrainedModel definition) { public Builder setDefinition(TrainedModelDefinition.Builder definition) {
this.definition = definition; this.definition = definition;
return this; return this;
} }
private Builder setDefinition(List<TrainedModel> definition) {
assert definition.size() == 1;
return setDefinition(definition.get(0));
}
public TrainedModelConfig build() { public TrainedModelConfig build() {
return new TrainedModelConfig( return new TrainedModelConfig(
modelId, modelId,
@ -292,7 +280,7 @@ public class TrainedModelConfig implements ToXContentObject {
createdTime, createdTime,
modelVersion, modelVersion,
modelType, modelType,
definition, definition == null ? null : definition.build(),
metadata); metadata);
} }
} }

View File

@ -0,0 +1,137 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
 * Client-side view of a trained model definition: a single {@link TrainedModel}
 * plus the list of {@link PreProcessor}s that accompany it.
 *
 * Instances are created through {@link Builder}; parsing also yields a {@link Builder}
 * so that the caller decides when to materialize the immutable object.
 */
public class TrainedModelDefinition implements ToXContentObject {

    public static final String NAME = "trained_model_doc";

    public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
    public static final ParseField PREPROCESSORS = new ParseField("preprocessors");

    // The parser is constructed with its boolean flag set to true
    // (presumably "ignore unknown fields" -- confirm against ObjectParser).
    public static final ObjectParser<Builder, Void> PARSER =
        new ObjectParser<>(NAME, true, TrainedModelDefinition.Builder::new);

    static {
        PARSER.declareNamedObjects(
            TrainedModelDefinition.Builder::setTrainedModel,
            (parser, context, name) -> parser.namedObject(TrainedModel.class, name, null),
            (definitionBuilder) -> { /* Noop does not matter client side*/ },
            TRAINED_MODEL);
        PARSER.declareNamedObjects(
            TrainedModelDefinition.Builder::setPreProcessors,
            (parser, context, name) -> parser.namedObject(PreProcessor.class, name, null),
            (definitionBuilder) -> {/* Does not matter client side*/ },
            PREPROCESSORS);
    }

    /**
     * Parses a definition from the given parser.
     *
     * @param parser the content parser positioned at the definition object
     * @return a builder populated with the parsed values
     * @throws IOException if the underlying parser fails
     */
    public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
        return PARSER.parse(parser, null);
    }

    private final TrainedModel trainedModel;
    private final List<PreProcessor> preProcessors;

    /**
     * @param trainedModel  the trained model
     * @param preProcessors the pre-processors; {@code null} is stored as an empty list,
     *                      anything else is wrapped in an unmodifiable view
     */
    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
        this.trainedModel = trainedModel;
        if (preProcessors == null) {
            this.preProcessors = Collections.emptyList();
        } else {
            this.preProcessors = Collections.unmodifiableList(preProcessors);
        }
    }

    /**
     * Writes this definition as an object holding the trained model followed by the pre-processors.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // NOTE(review): the boolean passed to the helper differs between the two calls
        // (false for the model, true for the pre-processors); presumably it selects
        // ordered/array output -- confirm against NamedXContentObjectHelper.
        NamedXContentObjectHelper.writeNamedObjects(
            builder, params, false, TRAINED_MODEL.getPreferredName(), Collections.singletonList(trainedModel));
        NamedXContentObjectHelper.writeNamedObjects(
            builder, params, true, PREPROCESSORS.getPreferredName(), preProcessors);
        builder.endObject();
        return builder;
    }

    /** @return the trained model held by this definition */
    public TrainedModel getTrainedModel() {
        return trainedModel;
    }

    /** @return the pre-processors held by this definition; never {@code null} */
    public List<PreProcessor> getPreProcessors() {
        return preProcessors;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || getClass() != other.getClass()) {
            return false;
        }
        TrainedModelDefinition rhs = (TrainedModelDefinition) other;
        return Objects.equals(trainedModel, rhs.trainedModel)
            && Objects.equals(preProcessors, rhs.preProcessors);
    }

    @Override
    public int hashCode() {
        return Objects.hash(trainedModel, preProcessors);
    }

    /**
     * Mutable builder for {@link TrainedModelDefinition}.
     */
    public static class Builder {

        private List<PreProcessor> preProcessors;
        private TrainedModel trainedModel;

        public Builder setPreProcessors(List<PreProcessor> preProcessors) {
            this.preProcessors = preProcessors;
            return this;
        }

        public Builder setTrainedModel(TrainedModel trainedModel) {
            this.trainedModel = trainedModel;
            return this;
        }

        // Parser hook: named objects arrive as a list, which must hold exactly one model.
        private Builder setTrainedModel(List<TrainedModel> trainedModels) {
            assert trainedModels.size() == 1;
            return setTrainedModel(trainedModels.get(0));
        }

        /** @return a new definition built from the current state of this builder */
        public TrainedModelDefinition build() {
            return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
        }
    }
}

View File

@ -18,13 +18,13 @@
*/ */
package org.elasticsearch.client.ml.inference.preprocessing; package org.elasticsearch.client.ml.inference.preprocessing;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.client.ml.inference.NamedXContentObject;
/** /**
* Describes a pre-processor for a defined machine learning model * Describes a pre-processor for a defined machine learning model
*/ */
public interface PreProcessor extends ToXContentObject { public interface PreProcessor extends NamedXContentObject {
/** /**
* @return The name of the pre-processor * @return The name of the pre-processor

View File

@ -19,7 +19,6 @@
package org.elasticsearch.client.ml.inference; package org.elasticsearch.client.ml.inference;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
@ -61,7 +60,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
Instant.ofEpochMilli(randomNonNegativeLong()), Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : randomNonNegativeLong(), randomBoolean() ? null : randomNonNegativeLong(),
randomAlphaOfLength(10), randomAlphaOfLength(10),
randomFrom(TreeTests.createRandom()), randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
} }

View File

@ -0,0 +1,83 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * XContent round-trip tests for the client-side {@link TrainedModelDefinition}.
 */
public class TrainedModelDefinitionTests extends AbstractXContentTestCase<TrainedModelDefinition> {

    @Override
    protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
        // Parsing yields a builder; materialize it for comparison with the test instance.
        TrainedModelDefinition.Builder parsed = TrainedModelDefinition.fromXContent(parser);
        return parsed.build();
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        // Every non-empty field path is excluded (presumably leaving only the root object
        // eligible for random unknown fields -- confirm against AbstractXContentTestCase).
        return field -> field.isEmpty() == false;
    }

    /**
     * Creates a builder holding a random tree model and, half of the time, between one and ten
     * randomly chosen pre-processors (otherwise the pre-processor list is left {@code null}).
     */
    public static TrainedModelDefinition.Builder createRandomBuilder() {
        int numberOfProcessors = randomIntBetween(1, 10);
        TrainedModelDefinition.Builder builder = new TrainedModelDefinition.Builder();
        builder.setPreProcessors(
            randomBoolean()
                ? null
                : Stream.generate(
                        () -> randomFrom(
                            FrequencyEncodingTests.createRandom(),
                            OneHotEncodingTests.createRandom(),
                            TargetMeanEncodingTests.createRandom()))
                    .limit(numberOfProcessors)
                    .collect(Collectors.toList()));
        builder.setTrainedModel(randomFrom(TreeTests.createRandom()));
        return builder;
    }

    @Override
    protected TrainedModelDefinition createTestInstance() {
        return createRandomBuilder().build();
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // Inference named objects first, followed by everything the search module registers.
        List<NamedXContentRegistry.Entry> entries =
            new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
        entries.addAll(searchModule.getNamedXContents());
        return new NamedXContentRegistry(entries);
    }
}

View File

@ -17,18 +17,13 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.common.time.TimeUtils; import org.elasticsearch.xpack.core.common.time.TimeUtils;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlStrings; import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException; import java.io.IOException;
import java.time.Instant; import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -65,11 +60,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
parser.declareNamedObjects(TrainedModelConfig.Builder::setDefinition, parser.declareObject(TrainedModelConfig.Builder::setDefinition,
(p, c, n) -> ignoreUnknownFields ? (p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
(modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
DEFINITION); DEFINITION);
return parser; return parser;
} }
@ -94,7 +86,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
// TODO how to reference and store large models that will not be executed in Java??? // TODO how to reference and store large models that will not be executed in Java???
// Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something // Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
// TODO Should this be lazily parsed when loading via the index??? // TODO Should this be lazily parsed when loading via the index???
private final TrainedModel definition; private final TrainedModelDefinition definition;
TrainedModelConfig(String modelId, TrainedModelConfig(String modelId,
String createdBy, String createdBy,
Version version, Version version,
@ -102,7 +94,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Instant createdTime, Instant createdTime,
Long modelVersion, Long modelVersion,
String modelType, String modelType,
TrainedModel definition, TrainedModelDefinition definition,
Map<String, Object> metadata) { Map<String, Object> metadata) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
@ -123,7 +115,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
createdTime = in.readInstant(); createdTime = in.readInstant();
modelVersion = in.readVLong(); modelVersion = in.readVLong();
modelType = in.readString(); modelType = in.readString();
definition = in.readOptionalNamedWriteable(TrainedModel.class); definition = in.readOptionalWriteable(TrainedModelDefinition::new);
metadata = in.readMap(); metadata = in.readMap();
} }
@ -160,7 +152,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
} }
@Nullable @Nullable
public TrainedModel getDefinition() { public TrainedModelDefinition getDefinition() {
return definition; return definition;
} }
@ -177,7 +169,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
out.writeInstant(createdTime); out.writeInstant(createdTime);
out.writeVLong(modelVersion); out.writeVLong(modelVersion);
out.writeString(modelType); out.writeString(modelType);
out.writeOptionalNamedWriteable(definition); out.writeOptionalWriteable(definition);
out.writeMap(metadata); out.writeMap(metadata);
} }
@ -194,11 +186,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
builder.field(MODEL_VERSION.getPreferredName(), modelVersion); builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
builder.field(MODEL_TYPE.getPreferredName(), modelType); builder.field(MODEL_TYPE.getPreferredName(), modelType);
if (definition != null) { if (definition != null) {
NamedXContentObjectHelper.writeNamedObjects(builder, builder.field(DEFINITION.getPreferredName(), definition);
params,
false,
DEFINITION.getPreferredName(),
Collections.singletonList(definition));
} }
if (metadata != null) { if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata); builder.field(METADATA.getPreferredName(), metadata);
@ -241,7 +229,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
modelVersion); modelVersion);
} }
public static class Builder { public static class Builder {
private String modelId; private String modelId;
@ -252,7 +239,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private Long modelVersion; private Long modelVersion;
private String modelType; private String modelType;
private Map<String, Object> metadata; private Map<String, Object> metadata;
private TrainedModel definition; private TrainedModelDefinition.Builder definition;
public Builder setModelId(String modelId) { public Builder setModelId(String modelId) {
this.modelId = modelId; this.modelId = modelId;
@ -298,19 +285,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this; return this;
} }
public Builder setDefinition(TrainedModel definition) { public Builder setDefinition(TrainedModelDefinition.Builder definition) {
this.definition = definition; this.definition = definition;
return this; return this;
} }
private Builder setDefinition(List<TrainedModel> definition) {
if (definition.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
DEFINITION.getPreferredName());
}
return setDefinition(definition.get(0));
}
// TODO move to REST level instead of here in the builder // TODO move to REST level instead of here in the builder
public void validate() { public void validate() {
// We require a definition to be available until we support other means of supplying the definition // We require a definition to be available until we support other means of supplying the definition
@ -352,7 +331,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
createdTime, createdTime,
modelVersion, modelVersion,
modelType, modelType,
definition, definition == null ? null : definition.build(),
metadata); metadata);
} }
@ -365,7 +344,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Instant.now(), Instant.now(),
modelVersion, modelVersion,
modelType, modelType,
definition, definition == null ? null : definition.build(),
metadata); metadata);
} }
} }

View File

@ -0,0 +1,176 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
 * A trained model definition: exactly one {@link TrainedModel} together with the list of
 * {@link PreProcessor}s that accompany it.
 *
 * The definition can be parsed from XContent (leniently or strictly), rendered back to
 * XContent, and sent over the wire via {@link Writeable}.
 */
public class TrainedModelDefinition implements ToXContentObject, Writeable {

    public static final String NAME = "trained_model_doc";

    public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
    public static final ParseField PREPROCESSORS = new ParseField("preprocessors");

    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
    public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
    public static final ObjectParser<TrainedModelDefinition.Builder, Void> STRICT_PARSER = createParser(false);

    /**
     * Builds a parser for the definition.
     *
     * @param ignoreUnknownFields whether unknown fields are tolerated; also selects the lenient
     *                            or strict variants of the named model and pre-processor objects
     */
    private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) {
        // Parser-created builders must start with processorsInOrder == false so that build() can
        // reject several preprocessors that were not supplied in ordered form. The public no-arg
        // constructor starts with the flag set to true, which would make that check unreachable
        // for parsed content.
        ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME,
            ignoreUnknownFields,
            TrainedModelDefinition.Builder::builderForParser);
        parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
            (p, c, n) -> ignoreUnknownFields ?
                p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
                p.namedObject(StrictlyParsedTrainedModel.class, n, null),
            (modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
            TRAINED_MODEL);
        parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
            (p, c, n) -> ignoreUnknownFields ?
                p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
                p.namedObject(StrictlyParsedPreProcessor.class, n, null),
            // Callback that marks the preprocessors as ordered (presumably invoked when they are
            // supplied as an array -- see ObjectParser#declareNamedObjects).
            (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
            PREPROCESSORS);
        return parser;
    }

    /**
     * Parses a definition from the given parser.
     *
     * @param parser  the content parser positioned at the definition object
     * @param lenient {@code true} to use {@link #LENIENT_PARSER}, {@code false} for {@link #STRICT_PARSER}
     * @return a builder populated with the parsed values
     * @throws IOException if the underlying parser fails
     */
    public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException {
        return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
    }

    private final TrainedModel trainedModel;
    private final List<PreProcessor> preProcessors;

    /**
     * @param trainedModel  the trained model
     * @param preProcessors the pre-processors; {@code null} is stored as an empty list,
     *                      anything else is wrapped in an unmodifiable view
     */
    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
        this.trainedModel = trainedModel;
        this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
    }

    /**
     * Reads a definition from the wire, in the order written by {@link #writeTo(StreamOutput)}.
     *
     * @throws IOException if reading from the stream fails
     */
    public TrainedModelDefinition(StreamInput in) throws IOException {
        this.trainedModel = in.readNamedWriteable(TrainedModel.class);
        // Wrap the list so that instances read from the wire expose the same read-only view
        // as instances created through the builder.
        this.preProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeNamedWriteable(trainedModel);
        out.writeNamedWriteableList(preProcessors);
    }

    /**
     * Writes this definition as an object holding the trained model followed by the pre-processors.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // NOTE(review): the boolean passed to the helper differs between the two calls
        // (false for the model, true for the pre-processors); presumably it selects
        // ordered/array output -- confirm against NamedXContentObjectHelper.
        NamedXContentObjectHelper.writeNamedObjects(builder,
            params,
            false,
            TRAINED_MODEL.getPreferredName(),
            Collections.singletonList(trainedModel));
        NamedXContentObjectHelper.writeNamedObjects(builder,
            params,
            true,
            PREPROCESSORS.getPreferredName(),
            preProcessors);
        builder.endObject();
        return builder;
    }

    /** @return the trained model held by this definition */
    public TrainedModel getTrainedModel() {
        return trainedModel;
    }

    /** @return a read-only view of the pre-processors held by this definition; never {@code null} */
    public List<PreProcessor> getPreProcessors() {
        return preProcessors;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TrainedModelDefinition that = (TrainedModelDefinition) o;
        return Objects.equals(trainedModel, that.trainedModel) &&
            Objects.equals(preProcessors, that.preProcessors) ;
    }

    @Override
    public int hashCode() {
        return Objects.hash(trainedModel, preProcessors);
    }

    /**
     * Mutable builder for {@link TrainedModelDefinition}.
     *
     * Builders created through the public constructor treat their preprocessors as ordered.
     * Builders created for the parser start unordered and are flipped to ordered by the
     * parser callback registered for the preprocessors field.
     */
    public static class Builder {

        private List<PreProcessor> preProcessors;
        private TrainedModel trainedModel;
        private boolean processorsInOrder;

        // Factory used by the XContent parsers: ordering has to be established during parsing.
        private static Builder builderForParser() {
            return new Builder(false);
        }

        private Builder(boolean processorsInOrder) {
            this.processorsInOrder = processorsInOrder;
        }

        public Builder() {
            this(true);
        }

        public Builder setPreProcessors(List<PreProcessor> preProcessors) {
            this.preProcessors = preProcessors;
            return this;
        }

        public Builder setTrainedModel(TrainedModel trainedModel) {
            this.trainedModel = trainedModel;
            return this;
        }

        // Parser hook: named objects arrive as a list, which must hold exactly one model.
        private Builder setTrainedModel(List<TrainedModel> trainedModel) {
            if (trainedModel.size() != 1) {
                throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
                    TRAINED_MODEL.getPreferredName());
            }
            return setTrainedModel(trainedModel.get(0));
        }

        private void setProcessorsInOrder(boolean value) {
            this.processorsInOrder = value;
        }

        /**
         * Builds the definition.
         *
         * @return a new definition built from the current state of this builder
         * @throws IllegalArgumentException if more than one preprocessor is set while the
         *                                  preprocessors are not marked as ordered
         */
        public TrainedModelDefinition build() {
            if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
                throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
            }
            return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
        }
    }
}

View File

@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.MlStrings; import org.elasticsearch.xpack.core.ml.utils.MlStrings;
import org.junit.Before; import org.junit.Before;
@ -65,7 +64,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
Instant.ofEpochMilli(randomNonNegativeLong()), Instant.ofEpochMilli(randomNonNegativeLong()),
randomBoolean() ? null : randomNonNegativeLong(), randomBoolean() ? null : randomNonNegativeLong(),
randomAlphaOfLength(10), randomAlphaOfLength(10),
randomBoolean() ? null : randomFrom(TreeTests.createRandom()), randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
} }
@ -97,14 +96,18 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
public void testValidateWithInvalidID() { public void testValidateWithInvalidID() {
String modelId = "InvalidID-"; String modelId = "InvalidID-";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate()); () -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId))); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
} }
public void testValidateWithLongID() { public void testValidateWithLongID() {
String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining()); String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
ElasticsearchException ex = expectThrows(ElasticsearchException.class, ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate()); () -> TrainedModelConfig.builder()
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT))); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
} }
@ -112,21 +115,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
String modelId = "simplemodel"; String modelId = "simplemodel";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(randomFrom(TreeTests.createRandom())) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreatedTime(Instant.now()) .setCreatedTime(Instant.now())
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation")); assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
ex = expectThrows(ElasticsearchException.class, ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(randomFrom(TreeTests.createRandom())) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setVersion(Version.CURRENT) .setVersion(Version.CURRENT)
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation")); assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
ex = expectThrows(ElasticsearchException.class, ex = expectThrows(ElasticsearchException.class,
() -> TrainedModelConfig.builder() () -> TrainedModelConfig.builder()
.setDefinition(randomFrom(TreeTests.createRandom())) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setCreatedBy("ml_user") .setCreatedBy("ml_user")
.setModelId(modelId).validate()); .setModelId(modelId).validate());
assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation")); assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));

View File

@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Tests for {@link TrainedModelDefinition} built on {@link AbstractSerializingTestCase}.
 * Also exposes {@link #createRandomBuilder()} so that other tests can obtain a random definition.
 */
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {

    /** Parsing mode passed to {@code fromXContent}; drawn at random by {@link #chooseStrictOrLenient()}. */
    private boolean lenient;

    /** Picks lenient or strict parsing at random before a test runs. */
    @Before
    public void chooseStrictOrLenient() {
        this.lenient = randomBoolean();
    }

    /**
     * Parses a definition from {@code parser} using the currently selected parsing mode
     * and builds the resulting instance.
     */
    @Override
    protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
        return TrainedModelDefinition.fromXContent(parser, lenient).build();
    }

    /** Unknown fields are reported as supported exactly when the lenient mode was selected. */
    @Override
    protected boolean supportsUnknownFields() {
        return lenient;
    }

    /**
     * Returns a filter that matches every non-empty field path.
     * NOTE(review): presumably this restricts random-field insertion to the root object - confirm
     * against the base test class.
     */
    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        return path -> path.isEmpty() == false;
    }

    /**
     * Creates a builder populated with random content: the pre-processors are either {@code null}
     * or a list of 1 to 10 randomly chosen encodings, and the trained model is a random tree.
     *
     * @return a randomly populated builder; callers invoke {@code build()} themselves when needed
     */
    public static TrainedModelDefinition.Builder createRandomBuilder() {
        // The order of the random draws below is kept stable so seeded runs stay reproducible.
        final int preProcessorCount = randomIntBetween(1, 10);
        final TrainedModelDefinition.Builder builder = new TrainedModelDefinition.Builder();
        final boolean omitPreProcessors = randomBoolean();
        return builder
            .setPreProcessors(omitPreProcessors
                ? null
                : Stream.generate(() -> randomFrom(
                        FrequencyEncodingTests.createRandom(),
                        OneHotEncodingTests.createRandom(),
                        TargetMeanEncodingTests.createRandom()))
                    .limit(preProcessorCount)
                    .collect(Collectors.toList()))
            .setTrainedModel(randomFrom(TreeTests.createRandom()));
    }

    /** Builds a fresh random instance from {@link #createRandomBuilder()}. */
    @Override
    protected TrainedModelDefinition createTestInstance() {
        return createRandomBuilder().build();
    }

    /** Reader that deserializes an instance through the stream constructor. */
    @Override
    protected Writeable.Reader<TrainedModelDefinition> instanceReader() {
        return TrainedModelDefinition::new;
    }

    /**
     * Registry combining the ML inference named x-content parsers with those of a default
     * {@link SearchModule}.
     */
    @Override
    protected NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> registryEntries =
            new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        registryEntries.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(registryEntries);
    }

    /** Registry holding the named writeables provided by {@link MlInferenceNamedXContentProvider}. */
    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        List<NamedWriteableRegistry.Entry> writeableEntries =
            new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
        return new NamedWriteableRegistry(writeableEntries);
    }
}

View File

@ -11,7 +11,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -93,7 +93,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) { private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
return TrainedModelConfig.builder() return TrainedModelConfig.builder()
.setCreatedBy("ml_test") .setCreatedBy("ml_test")
.setDefinition(TreeTests.createRandom()) .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.setDescription("trained model config for test") .setDescription("trained model config for test")
.setModelId(modelId) .setModelId(modelId)
.setModelType("binary_decision_tree") .setModelType("binary_decision_tree")