[7.x] [ML][Inference] adjusting definition object schema and validation (#47447) (#47673)

* [ML][Inference] adjusting definition object schema and validation (#47447)

* [ML][Inference] adjusting definition object schema and validation

* finalizing schema and fixing an inference NPE

* addressing PR comments

* fixing for backport
Benjamin Trent 2019-10-08 07:11:05 -04:00 committed by GitHub
parent ce91ba7c25
commit d33dbf82d4
10 changed files with 459 additions and 63 deletions

View File: org.elasticsearch.client.ml.inference.TrainedModelDefinition

@@ -22,6 +22,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -38,6 +39,7 @@ public class TrainedModelDefinition implements ToXContentObject {
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
public static final ParseField INPUT = new ParseField("input");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
@@ -51,6 +53,7 @@ public class TrainedModelDefinition implements ToXContentObject {
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
PREPROCESSORS);
PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
}
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
@@ -59,10 +62,12 @@ public class TrainedModelDefinition implements ToXContentObject {
private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private final Input input;
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
this.trainedModel = trainedModel;
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
this.input = input;
}
@Override
@@ -78,6 +83,9 @@ public class TrainedModelDefinition implements ToXContentObject {
true,
PREPROCESSORS.getPreferredName(),
preProcessors);
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
builder.endObject();
return builder;
}
@@ -90,6 +98,10 @@ public class TrainedModelDefinition implements ToXContentObject {
return preProcessors;
}
public Input getInput() {
return input;
}
@Override
public String toString() {
return Strings.toString(this);
@@ -101,18 +113,20 @@ public class TrainedModelDefinition implements ToXContentObject {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(preProcessors, that.preProcessors) ;
Objects.equals(preProcessors, that.preProcessors) &&
Objects.equals(input, that.input);
}
@Override
public int hashCode() {
return Objects.hash(trainedModel, preProcessors);
return Objects.hash(trainedModel, preProcessors, input);
}
public static class Builder {
private List<PreProcessor> preProcessors;
private TrainedModel trainedModel;
private Input input;
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
this.preProcessors = preProcessors;
@@ -124,14 +138,71 @@ public class TrainedModelDefinition implements ToXContentObject {
return this;
}
public Builder setInput(Input input) {
this.input = input;
return this;
}
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
assert trainedModel.size() == 1;
return setTrainedModel(trainedModel.get(0));
}
public TrainedModelDefinition build() {
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
}
}
public static class Input implements ToXContentObject {
public static final String NAME = "trained_mode_definition_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Input((List<String>)a[0]));
static {
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
}
public static Input fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}
private final List<String> fieldNames;
public Input(List<String> fieldNames) {
this.fieldNames = fieldNames;
}
public List<String> getFieldNames() {
return fieldNames;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fieldNames != null) {
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
return Objects.equals(fieldNames, that.fieldNames);
}
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}
}
}
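
For orientation, a minimal client-side sketch of the new builder flow. The model and preProcessor arguments stand in for instances built elsewhere, and the sketch assumes the Builder's no-arg constructor and a public single-model setTrainedModel(TrainedModel) setter, as the private list overload and the tests in this commit suggest:

import java.util.Arrays;
import java.util.Collections;

import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;

class DefinitionBuilderSketch {
    static TrainedModelDefinition buildExample(TrainedModel model, PreProcessor preProcessor) {
        // The new Input block names the raw document fields the pipeline consumes.
        TrainedModelDefinition.Input input =
            new TrainedModelDefinition.Input(Arrays.asList("col1", "col2", "col3", "col4"));
        return new TrainedModelDefinition.Builder()
            .setInput(input) // setter introduced by this commit
            .setPreProcessors(Collections.singletonList(preProcessor))
            .setTrainedModel(model)
            .build();
    }
}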

View File: org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding

@@ -39,7 +39,7 @@ public class TargetMeanEncoding implements PreProcessor {
public static final String NAME = "target_mean_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField TARGET_MEANS = new ParseField("target_means");
public static final ParseField TARGET_MAP = new ParseField("target_map");
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
@SuppressWarnings("unchecked")
@@ -52,7 +52,7 @@ public class TargetMeanEncoding implements PreProcessor {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
TARGET_MEANS);
TARGET_MAP);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
}
@@ -110,7 +110,7 @@ public class TargetMeanEncoding implements PreProcessor {
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(TARGET_MEANS.getPreferredName(), meanMap);
builder.field(TARGET_MAP.getPreferredName(), meanMap);
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
builder.endObject();
return builder;
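
The only substantive change in this file is the XContent key rename from target_means to target_map. A hedged sketch of building an encoding that now serializes under the new key, assuming the client constructor mirrors the server-side signature shown further down in this diff:

import java.util.HashMap;
import java.util.Map;

import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;

class TargetMapRenameSketch {
    static TargetMeanEncoding example() {
        Map<String, Double> means = new HashMap<>();
        means.put("S", 5.0);   // values mirror the test fixture later in this commit
        means.put("M", 10.0);
        means.put("L", 20.0);
        // toXContent now emits this map under "target_map" instead of "target_means".
        return new TargetMeanEncoding("col2", "col2_encoded", means, 5.0);
    }
}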

View File: org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests

@@ -64,7 +64,10 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors)
.collect(Collectors.toList()))
.setTrainedModel(randomFrom(TreeTests.createRandom()));
.setTrainedModel(randomFrom(TreeTests.createRandom()))
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(1, 10))
.collect(Collectors.toList())));
}
@Override

View File: org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition

@@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -30,10 +31,11 @@ import java.util.Objects;
public class TrainedModelDefinition implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_doc";
public static final String NAME = "trained_mode_definition";
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
public static final ParseField INPUT = new ParseField("input");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
@@ -55,6 +57,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
PREPROCESSORS);
parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT);
return parser;
}
@@ -64,21 +67,25 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private final Input input;
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = trainedModel;
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
}
public TrainedModelDefinition(StreamInput in) throws IOException {
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
this.input = new Input(in);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(trainedModel);
out.writeNamedWriteableList(preProcessors);
input.writeTo(out);
}
@Override
@@ -94,6 +101,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
true,
PREPROCESSORS.getPreferredName(),
preProcessors);
builder.field(INPUT.getPreferredName(), input);
builder.endObject();
return builder;
}
@@ -106,6 +114,10 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
return preProcessors;
}
public Input getInput() {
return input;
}
@Override
public String toString() {
return Strings.toString(this);
@@ -117,12 +129,13 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(preProcessors, that.preProcessors) ;
Objects.equals(input, that.input) &&
Objects.equals(preProcessors, that.preProcessors);
}
@Override
public int hashCode() {
return Objects.hash(trainedModel, preProcessors);
return Objects.hash(trainedModel, input, preProcessors);
}
public static class Builder {
@@ -130,6 +143,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
private List<PreProcessor> preProcessors;
private TrainedModel trainedModel;
private boolean processorsInOrder;
private Input input;
private static Builder builderForParser() {
return new Builder(false);
@@ -153,6 +167,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
return this;
}
public Builder setInput(Input input) {
this.input = input;
return this;
}
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
if (trainedModel.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
@@ -169,8 +188,71 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
}
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
}
}
public static class Input implements ToXContentObject, Writeable {
public static final String NAME = "trained_mode_definition_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");
public static final ConstructingObjectParser<Input, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<Input, Void> STRICT_PARSER = createParser(false);
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Input, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<Input, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new Input((List<String>)a[0]));
parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
return parser;
}
public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
private final List<String> fieldNames;
public Input(List<String> fieldNames) {
this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
}
public Input(StreamInput in) throws IOException {
this.fieldNames = Collections.unmodifiableList(in.readStringList());
}
public List<String> getFieldNames() {
return fieldNames;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(fieldNames);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
return Objects.equals(fieldNames, that.fieldNames);
}
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}
}
}
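
A small sketch of the strict/lenient split described by the comment in this file: user-supplied config goes through the strict parser, while internally stored documents are read leniently so newer fields do not break older nodes. The XContentParser argument is a hypothetical parser positioned at a definition object:

import java.io.IOException;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;

class ParserModeSketch {
    // lenient == false: user-supplied config, unknown fields are rejected.
    static TrainedModelDefinition parseConfig(XContentParser parser) throws IOException {
        return TrainedModelDefinition.fromXContent(parser, false).build();
    }

    // lenient == true: stored documents, unknown fields are ignored.
    static TrainedModelDefinition parseStored(XContentParser parser) throws IOException {
        return TrainedModelDefinition.fromXContent(parser, true).build();
    }
}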

View File: org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncoding

@@ -28,7 +28,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
public static final ParseField NAME = new ParseField("target_mean_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField TARGET_MEANS = new ParseField("target_means");
public static final ParseField TARGET_MAP = new ParseField("target_map");
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
public static final ConstructingObjectParser<TargetMeanEncoding, Void> STRICT_PARSER = createParser(false);
@@ -44,7 +44,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
parser.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
TARGET_MEANS);
TARGET_MAP);
parser.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
return parser;
}
@@ -65,7 +65,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
this.field = ExceptionsHelper.requireNonNull(field, FIELD);
this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME);
this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MEANS));
this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MAP));
this.defaultValue = ExceptionsHelper.requireNonNull(defaultValue, DEFAULT_VALUE);
}
@@ -136,7 +136,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(TARGET_MEANS.getPreferredName(), meanMap);
builder.field(TARGET_MAP.getPreferredName(), meanMap);
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
builder.endObject();
return builder;

View File: org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble

@@ -107,14 +107,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
@Override
public double infer(Map<String, Object> fields) {
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
return infer(features);
List<Double> processedInferences = inferAndProcess(fields);
return outputAggregator.aggregate(processedInferences);
}
@Override
public double infer(List<Double> fields) {
List<Double> processedInferences = inferAndProcess(fields);
return outputAggregator.aggregate(processedInferences);
throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
}
@Override
@@ -128,17 +127,12 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
return classificationProbability(features);
return inferAndProcess(fields);
}
@Override
public List<Double> classificationProbability(List<Double> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
return inferAndProcess(fields);
throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
}
@Override
@@ -146,7 +140,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return classificationLabels;
}
private List<Double> inferAndProcess(List<Double> fields) {
private List<Double> inferAndProcess(Map<String, Object> fields) {
List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
return outputAggregator.processValues(modelInferences);
}
@@ -210,15 +204,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
@Override
public void validate() {
if (this.featureNames != null) {
if (this.models.stream()
.anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) {
throw ExceptionsHelper.badRequestException(
"[{}] must be the same and in the same order for each of the {}",
FEATURE_NAMES.getPreferredName(),
TRAINED_MODELS.getPreferredName());
}
}
if (outputAggregator.expectedValueSize() != null &&
outputAggregator.expectedValueSize() != models.size()) {
throw ExceptionsHelper.badRequestException(
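
Net effect of the Ensemble changes: callers now hand the raw field map to the ensemble and each sub-model resolves its own feature_names against it, which is why the per-model feature-name validation in validate() is removed. A minimal sketch of the new calling convention, with ensemble standing in for a built instance:

import java.util.HashMap;
import java.util.Map;

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;

class EnsembleInferenceSketch {
    static double infer(Ensemble ensemble) {
        Map<String, Object> fields = new HashMap<>();
        fields.put("foo", 0.3);
        fields.put("bar", 0.5);
        // Each sub-model pulls its own features out of the map; the
        // list-based infer(List<Double>) overload now throws
        // UnsupportedOperationException.
        return ensemble.infer(fields);
    }
}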

View File: org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree

@@ -106,7 +106,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
@Override
public double infer(Map<String, Object> fields) {
List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
List<Double> features = featureNames.stream().map(f ->
fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null
).collect(Collectors.toList());
return infer(features);
}
@@ -146,7 +148,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()));
List<Double> features = featureNames.stream().map(f ->
fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null)
.collect(Collectors.toList());
return classificationProbability(features);
}
@Override
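
With this change a missing or non-numeric field becomes a null feature, and traversal then follows the node's default_left flag (exercised by the TreeTests changes below). A minimal sketch, with tree standing in for a built instance:

import java.util.HashMap;
import java.util.Map;

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;

class TreeMissingValueSketch {
    static double infer(Tree tree) {
        Map<String, Object> fields = new HashMap<>();
        fields.put("foo", 0.3);
        fields.put("bar", null); // missing: null feature, traversal takes default_left
        // Previously a missing value could surface as the inference NPE
        // referenced in the commit message.
        return tree.infer(fields);
    }
}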

View File: org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests

@@ -8,13 +8,18 @@ package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.junit.Before;
@@ -26,6 +31,8 @@ import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
@@ -61,8 +68,229 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors)
.collect(Collectors.toList()))
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(1, 10))
.collect(Collectors.toList())))
.setTrainedModel(randomFrom(TreeTests.createRandom()));
}
private static final String ENSEMBLE_MODEL = "" +
"{\n" +
" \"input\": {\n" +
" \"field_names\": [\n" +
" \"col1\",\n" +
" \"col2\",\n" +
" \"col3\",\n" +
" \"col4\"\n" +
" ]\n" +
" },\n" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"col1\",\n" +
" \"hot_map\": {\n" +
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" +
" \"target_map\": {\n" +
" \"S\": 5.0,\n" +
" \"M\": 10.0,\n" +
" \"L\": 20\n" +
" },\n" +
" \"default_value\": 5.0\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"col3\",\n" +
" \"feature_name\": \"col3_encoded\",\n" +
" \"frequency_map\": {\n" +
" \"none\": 0.75,\n" +
" \"true\": 0.10,\n" +
" \"false\": 0.15\n" +
" }\n" +
" }\n" +
" }\n" +
" ],\n" +
" \"trained_model\": {\n" +
" \"ensemble\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"aggregate_output\": {\n" +
" \"weighted_sum\": {\n" +
" \"weights\": [\n" +
" 0.5,\n" +
" 0.5\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"target_type\": \"regression\",\n" +
" \"trained_models\": [\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" },\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" }\n" +
"}";
private static final String TREE_MODEL = "" +
"{\n" +
" \"input\": {\n" +
" \"field_names\": [\n" +
" \"col1\",\n" +
" \"col2\",\n" +
" \"col3\",\n" +
" \"col4\"\n" +
" ]\n" +
" },\n" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"col1\",\n" +
" \"hot_map\": {\n" +
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" +
" \"target_map\": {\n" +
" \"S\": 5.0,\n" +
" \"M\": 10.0,\n" +
" \"L\": 20\n" +
" },\n" +
" \"default_value\": 5.0\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"col3\",\n" +
" \"feature_name\": \"col3_encoded\",\n" +
" \"frequency_map\": {\n" +
" \"none\": 0.75,\n" +
" \"true\": 0.10,\n" +
" \"false\": 0.15\n" +
" }\n" +
" }\n" +
" }\n" +
" ],\n" +
" \"trained_model\": {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
"}";
public void testEnsembleSchemaDeserialization() throws IOException {
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, ENSEMBLE_MODEL);
TrainedModelDefinition definition = TrainedModelDefinition.fromXContent(parser, false).build();
assertThat(definition.getTrainedModel().getClass(), equalTo(Ensemble.class));
}
public void testTreeSchemaDeserialization() throws IOException {
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, TREE_MODEL);
TrainedModelDefinition definition = TrainedModelDefinition.fromXContent(parser, false).build();
assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class));
}
@Override
protected TrainedModelDefinition createTestInstance() {
return createRandomBuilder().build();

View File: org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests

@ -26,6 +26,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
@@ -108,25 +109,6 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
return new NamedWriteableRegistry(entries);
}
public void testEnsembleWithModelsThatHaveDifferentFeatureNames() {
List<String> featureNames = Arrays.asList("foo", "bar", "baz", "farequote");
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder().setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6)))
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder().setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6)))
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
}
public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
List<String> featureNames = Arrays.asList("foo", "bar");
int numberOfModels = 5;
@@ -279,6 +261,17 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
}
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
expected = Arrays.asList(0.6899744811, 0.3100255188);
probabilities = ensemble.classificationProbability(featureMap);
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
}
}
public void testClassificationInference() {
@@ -336,6 +329,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertEquals(0.0, ensemble.infer(featureMap), 0.00001);
}
public void testRegressionInference() {
@@ -394,6 +393,12 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {

View File: org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests

@@ -17,12 +17,14 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
@@ -118,19 +120,26 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertEquals(0.3, tree.infer(featureMap), 0.00001);
assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(0.1, tree.infer(featureMap), 0.00001);
assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(0.2, tree.infer(featureMap), 0.00001);
assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001));
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
}
public void testTreeClassificationProbability() {
@@ -162,6 +171,13 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap));
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertEquals(1.0, tree.infer(featureMap), 0.00001);
}
public void testTreeWithNullRoot() {