From d33dbf82d47ee4962cdf6c3616dbf9f57d5e3282 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 8 Oct 2019 07:11:05 -0400 Subject: [PATCH] [7.x] [ML][Inference] adjusting definition object schema and validation (#47447) (#47673) * [ML][Inference] adjusting definition object schema and validation (#47447) * [ML][Inference] adjusting definition object schema and validation * finalizing schema and fixing inference npe * addressing PR comments * fixing for backport --- .../ml/inference/TrainedModelDefinition.java | 79 +++++- .../preprocessing/TargetMeanEncoding.java | 6 +- .../TrainedModelDefinitionTests.java | 5 +- .../ml/inference/TrainedModelDefinition.java | 94 +++++++- .../preprocessing/TargetMeanEncoding.java | 8 +- .../trainedmodel/ensemble/Ensemble.java | 27 +-- .../ml/inference/trainedmodel/tree/Tree.java | 10 +- .../TrainedModelDefinitionTests.java | 228 ++++++++++++++++++ .../trainedmodel/ensemble/EnsembleTests.java | 43 ++-- .../trainedmodel/tree/TreeTests.java | 22 +- 10 files changed, 459 insertions(+), 63 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java index 7b564a9e684..dec834fa328 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java @@ -22,6 +22,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -38,6 +39,7 @@ public class TrainedModelDefinition implements ToXContentObject { public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); + public static final ParseField INPUT = new ParseField("input"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -51,6 +53,7 @@ public class TrainedModelDefinition implements ToXContentObject { (p, c, n) -> p.namedObject(PreProcessor.class, n, null), (trainedModelDefBuilder) -> {/* Does not matter client side*/ }, PREPROCESSORS); + PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT); } public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException { @@ -59,10 +62,12 @@ public class TrainedModelDefinition implements ToXContentObject { private final TrainedModel trainedModel; private final List preProcessors; + private final Input input; - TrainedModelDefinition(TrainedModel trainedModel, List preProcessors) { + TrainedModelDefinition(TrainedModel trainedModel, List preProcessors, Input input) { this.trainedModel = trainedModel; this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); + this.input = input; } @Override @@ -78,6 +83,9 @@ public class TrainedModelDefinition implements ToXContentObject { true, PREPROCESSORS.getPreferredName(), preProcessors); + if (input != null) { + builder.field(INPUT.getPreferredName(), input); + } builder.endObject(); return builder; } @@ -90,6 +98,10 @@ public class TrainedModelDefinition implements ToXContentObject { return preProcessors; } + public Input getInput() { + return input; + } + @Override public String toString() { return Strings.toString(this); @@ -101,18 +113,20 @@ public class TrainedModelDefinition implements ToXContentObject { if (o == null || getClass() != o.getClass()) return false; TrainedModelDefinition that = (TrainedModelDefinition) o; return Objects.equals(trainedModel, that.trainedModel) && - Objects.equals(preProcessors, that.preProcessors) ; + Objects.equals(preProcessors, that.preProcessors) && + Objects.equals(input, that.input); } @Override public int hashCode() { - return Objects.hash(trainedModel, preProcessors); + return Objects.hash(trainedModel, preProcessors, input); } public static class Builder { private List preProcessors; private TrainedModel trainedModel; + private Input input; public Builder setPreProcessors(List preProcessors) { this.preProcessors = preProcessors; @@ -124,14 +138,71 @@ public class TrainedModelDefinition implements ToXContentObject { return this; } + public Builder setInput(Input input) { + this.input = input; + return this; + } + private Builder setTrainedModel(List trainedModel) { assert trainedModel.size() == 1; return setTrainedModel(trainedModel.get(0)); } public TrainedModelDefinition build() { - return new TrainedModelDefinition(this.trainedModel, this.preProcessors); + return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input); } } + public static class Input implements ToXContentObject { + + public static final String NAME = "trained_mode_definition_input"; + public static final ParseField FIELD_NAMES = new ParseField("field_names"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new Input((List)a[0])); + static { + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); + } + + public static Input fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private final List fieldNames; + + public Input(List fieldNames) { + this.fieldNames = fieldNames; + } + + public List getFieldNames() { + return fieldNames; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (fieldNames != null) { + builder.field(FIELD_NAMES.getPreferredName(), fieldNames); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o; + return Objects.equals(fieldNames, that.fieldNames); + } + + @Override + public int hashCode() { + return Objects.hash(fieldNames); + } + + } + } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncoding.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncoding.java index bb29924b98e..18203f33018 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncoding.java @@ -39,7 +39,7 @@ public class TargetMeanEncoding implements PreProcessor { public static final String NAME = "target_mean_encoding"; public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); - public static final ParseField TARGET_MEANS = new ParseField("target_means"); + public static final ParseField TARGET_MAP = new ParseField("target_map"); public static final ParseField DEFAULT_VALUE = new ParseField("default_value"); @SuppressWarnings("unchecked") @@ -52,7 +52,7 @@ public class TargetMeanEncoding implements PreProcessor { PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue), - TARGET_MEANS); + TARGET_MAP); PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE); } @@ -110,7 +110,7 @@ public class TargetMeanEncoding implements PreProcessor { builder.startObject(); builder.field(FIELD.getPreferredName(), field); builder.field(FEATURE_NAME.getPreferredName(), featureName); - builder.field(TARGET_MEANS.getPreferredName(), meanMap); + builder.field(TARGET_MAP.getPreferredName(), meanMap); builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue); builder.endObject(); return builder; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java index 7c31e472477..dda04640ab0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java @@ -64,7 +64,10 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase randomAlphaOfLength(10)) + .limit(randomLongBetween(1, 10)) + .collect(Collectors.toList()))); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index 6daa530e027..f85c184646e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -30,10 +31,11 @@ import java.util.Objects; public class TrainedModelDefinition implements ToXContentObject, Writeable { - public static final String NAME = "trained_model_doc"; + public static final String NAME = "trained_mode_definition"; public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); + public static final ParseField INPUT = new ParseField("input"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -55,6 +57,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { p.namedObject(StrictlyParsedPreProcessor.class, n, null), (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true), PREPROCESSORS); + parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT); return parser; } @@ -64,21 +67,25 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { private final TrainedModel trainedModel; private final List preProcessors; + private final Input input; - TrainedModelDefinition(TrainedModel trainedModel, List preProcessors) { - this.trainedModel = trainedModel; + TrainedModelDefinition(TrainedModel trainedModel, List preProcessors, Input input) { + this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL); this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); + this.input = ExceptionsHelper.requireNonNull(input, INPUT); } public TrainedModelDefinition(StreamInput in) throws IOException { this.trainedModel = in.readNamedWriteable(TrainedModel.class); this.preProcessors = in.readNamedWriteableList(PreProcessor.class); + this.input = new Input(in); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(trainedModel); out.writeNamedWriteableList(preProcessors); + input.writeTo(out); } @Override @@ -94,6 +101,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { true, PREPROCESSORS.getPreferredName(), preProcessors); + builder.field(INPUT.getPreferredName(), input); builder.endObject(); return builder; } @@ -106,6 +114,10 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { return preProcessors; } + public Input getInput() { + return input; + } + @Override public String toString() { return Strings.toString(this); @@ -117,12 +129,13 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { if (o == null || getClass() != o.getClass()) return false; TrainedModelDefinition that = (TrainedModelDefinition) o; return Objects.equals(trainedModel, that.trainedModel) && - Objects.equals(preProcessors, that.preProcessors) ; + Objects.equals(input, that.input) && + Objects.equals(preProcessors, that.preProcessors); } @Override public int hashCode() { - return Objects.hash(trainedModel, preProcessors); + return Objects.hash(trainedModel, input, preProcessors); } public static class Builder { @@ -130,6 +143,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { private List preProcessors; private TrainedModel trainedModel; private boolean processorsInOrder; + private Input input; private static Builder builderForParser() { return new Builder(false); @@ -153,6 +167,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { return this; } + public Builder setInput(Input input) { + this.input = input; + return this; + } + private Builder setTrainedModel(List trainedModel) { if (trainedModel.size() != 1) { throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.", @@ -169,8 +188,71 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) { throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects"); } - return new TrainedModelDefinition(this.trainedModel, this.preProcessors); + return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input); } } + public static class Input implements ToXContentObject, Writeable { + + public static final String NAME = "trained_mode_definition_input"; + public static final ParseField FIELD_NAMES = new ParseField("field_names"); + + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new Input((List)a[0])); + parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); + return parser; + } + + public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + private final List fieldNames; + + public Input(List fieldNames) { + this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES)); + } + + public Input(StreamInput in) throws IOException { + this.fieldNames = Collections.unmodifiableList(in.readStringList()); + } + + public List getFieldNames() { + return fieldNames; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(fieldNames); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_NAMES.getPreferredName(), fieldNames); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o; + return Objects.equals(fieldNames, that.fieldNames); + } + + @Override + public int hashCode() { + return Objects.hash(fieldNames); + } + + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index ebce49db957..d8f413b3b17 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -28,7 +28,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly public static final ParseField NAME = new ParseField("target_mean_encoding"); public static final ParseField FIELD = new ParseField("field"); public static final ParseField FEATURE_NAME = new ParseField("feature_name"); - public static final ParseField TARGET_MEANS = new ParseField("target_means"); + public static final ParseField TARGET_MAP = new ParseField("target_map"); public static final ParseField DEFAULT_VALUE = new ParseField("default_value"); public static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -44,7 +44,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue), - TARGET_MEANS); + TARGET_MAP); parser.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE); return parser; } @@ -65,7 +65,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly public TargetMeanEncoding(String field, String featureName, Map meanMap, Double defaultValue) { this.field = ExceptionsHelper.requireNonNull(field, FIELD); this.featureName = ExceptionsHelper.requireNonNull(featureName, FEATURE_NAME); - this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MEANS)); + this.meanMap = Collections.unmodifiableMap(ExceptionsHelper.requireNonNull(meanMap, TARGET_MAP)); this.defaultValue = ExceptionsHelper.requireNonNull(defaultValue, DEFAULT_VALUE); } @@ -136,7 +136,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly builder.startObject(); builder.field(FIELD.getPreferredName(), field); builder.field(FEATURE_NAME.getPreferredName(), featureName); - builder.field(TARGET_MEANS.getPreferredName(), meanMap); + builder.field(TARGET_MAP.getPreferredName(), meanMap); builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue); builder.endObject(); return builder; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 7f2a7cc9a02..5e5199c2405 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -107,14 +107,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai @Override public double infer(Map fields) { - List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); - return infer(features); + List processedInferences = inferAndProcess(fields); + return outputAggregator.aggregate(processedInferences); } @Override public double infer(List fields) { - List processedInferences = inferAndProcess(fields); - return outputAggregator.aggregate(processedInferences); + throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); } @Override @@ -128,17 +127,12 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai throw new UnsupportedOperationException( "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); } - List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); - return classificationProbability(features); + return inferAndProcess(fields); } @Override public List classificationProbability(List fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - return inferAndProcess(fields); + throw new UnsupportedOperationException("Ensemble requires map containing field names and values"); } @Override @@ -146,7 +140,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai return classificationLabels; } - private List inferAndProcess(List fields) { + private List inferAndProcess(Map fields) { List modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList()); return outputAggregator.processValues(modelInferences); } @@ -210,15 +204,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai @Override public void validate() { - if (this.featureNames != null) { - if (this.models.stream() - .anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) { - throw ExceptionsHelper.badRequestException( - "[{}] must be the same and in the same order for each of the {}", - FEATURE_NAMES.getPreferredName(), - TRAINED_MODELS.getPreferredName()); - } - } if (outputAggregator.expectedValueSize() != null && outputAggregator.expectedValueSize() != models.size()) { throw ExceptionsHelper.badRequestException( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 5dca29d5843..3a91ec0cd86 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -106,7 +106,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM @Override public double infer(Map fields) { - List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); + List features = featureNames.stream().map(f -> + fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null + ).collect(Collectors.toList()); return infer(features); } @@ -146,7 +148,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM throw new UnsupportedOperationException( "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); } - return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList())); + List features = featureNames.stream().map(f -> + fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null) + .collect(Collectors.toList()); + + return classificationProbability(features); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 870a8c7049d..c31638b0479 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -8,13 +8,18 @@ package org.elasticsearch.xpack.core.ml.inference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.junit.Before; @@ -26,6 +31,8 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.Matchers.equalTo; + public class TrainedModelDefinitionTests extends AbstractSerializingTestCase { @@ -61,8 +68,229 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase randomAlphaOfLength(10)) + .limit(randomLongBetween(1, 10)) + .collect(Collectors.toList()))) .setTrainedModel(randomFrom(TreeTests.createRandom())); } + + private static final String ENSEMBLE_MODEL = "" + + "{\n" + + " \"input\": {\n" + + " \"field_names\": [\n" + + " \"col1\",\n" + + " \"col2\",\n" + + " \"col3\",\n" + + " \"col4\"\n" + + " ]\n" + + " },\n" + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " \"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"ensemble\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"aggregate_output\": {\n" + + " \"weighted_sum\": {\n" + + " \"weights\": [\n" + + " 0.5,\n" + + " 0.5\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"target_type\": \"regression\",\n" + + " \"trained_models\": [\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col2_encoded\",\n" + + " \"col3_encoded\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + private static final String TREE_MODEL = "" + + "{\n" + + " \"input\": {\n" + + " \"field_names\": [\n" + + " \"col1\",\n" + + " \"col2\",\n" + + " \"col3\",\n" + + " \"col4\"\n" + + " ]\n" + + " },\n" + + " \"preprocessors\": [\n" + + " {\n" + + " \"one_hot_encoding\": {\n" + + " \"field\": \"col1\",\n" + + " \"hot_map\": {\n" + + " \"male\": \"col1_male\",\n" + + " \"female\": \"col1_female\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"target_mean_encoding\": {\n" + + " \"field\": \"col2\",\n" + + " \"feature_name\": \"col2_encoded\",\n" + + " \"target_map\": {\n" + + " \"S\": 5.0,\n" + + " \"M\": 10.0,\n" + + " \"L\": 20\n" + + " },\n" + + " \"default_value\": 5.0\n" + + " }\n" + + " },\n" + + " {\n" + + " \"frequency_encoding\": {\n" + + " \"field\": \"col3\",\n" + + " \"feature_name\": \"col3_encoded\",\n" + + " \"frequency_map\": {\n" + + " \"none\": 0.75,\n" + + " \"true\": 0.10,\n" + + " \"false\": 0.15\n" + + " }\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"trained_model\": {\n" + + " \"tree\": {\n" + + " \"feature_names\": [\n" + + " \"col1_male\",\n" + + " \"col1_female\",\n" + + " \"col4\"\n" + + " ],\n" + + " \"tree_structure\": [\n" + + " {\n" + + " \"node_index\": 0,\n" + + " \"split_feature\": 0,\n" + + " \"split_gain\": 12.0,\n" + + " \"threshold\": 10.0,\n" + + " \"decision_type\": \"lte\",\n" + + " \"default_left\": true,\n" + + " \"left_child\": 1,\n" + + " \"right_child\": 2\n" + + " },\n" + + " {\n" + + " \"node_index\": 1,\n" + + " \"leaf_value\": 1\n" + + " },\n" + + " {\n" + + " \"node_index\": 2,\n" + + " \"leaf_value\": 2\n" + + " }\n" + + " ],\n" + + " \"target_type\": \"regression\"\n" + + " }\n" + + " }\n" + + "}"; + + public void testEnsembleSchemaDeserialization() throws IOException { + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, ENSEMBLE_MODEL); + TrainedModelDefinition definition = TrainedModelDefinition.fromXContent(parser, false).build(); + assertThat(definition.getTrainedModel().getClass(), equalTo(Ensemble.class)); + } + + public void testTreeSchemaDeserialization() throws IOException { + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, TREE_MODEL); + TrainedModelDefinition definition = TrainedModelDefinition.fromXContent(parser, false).build(); + assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class)); + } + @Override protected TrainedModelDefinition createTestInstance() { return createRandomBuilder().build(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 798007d0416..c03274132ef 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -108,25 +109,6 @@ public class EnsembleTests extends AbstractSerializingTestCase { return new NamedWriteableRegistry(entries); } - public void testEnsembleWithModelsThatHaveDifferentFeatureNames() { - List featureNames = Arrays.asList("foo", "bar", "baz", "farequote"); - ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { - Ensemble.builder().setFeatureNames(featureNames) - .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6))) - .build() - .validate(); - }); - assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models")); - - ex = expectThrows(ElasticsearchException.class, () -> { - Ensemble.builder().setFeatureNames(featureNames) - .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6))) - .build() - .validate(); - }); - assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models")); - } - public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() { List featureNames = Arrays.asList("foo", "bar"); int numberOfModels = 5; @@ -279,6 +261,17 @@ public class EnsembleTests extends AbstractSerializingTestCase { for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); } + + // This should handle missing values and take the default_left path + featureMap = new HashMap(2) {{ + put("foo", 0.3); + put("bar", null); + }}; + expected = Arrays.asList(0.6899744811, 0.3100255188); + probabilities = ensemble.classificationProbability(featureMap); + for(int i = 0; i < expected.size(); i++) { + assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + } } public void testClassificationInference() { @@ -336,6 +329,12 @@ public class EnsembleTests extends AbstractSerializingTestCase { featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + + featureMap = new HashMap(2) {{ + put("foo", 0.3); + put("bar", null); + }}; + assertEquals(0.0, ensemble.infer(featureMap), 0.00001); } public void testRegressionInference() { @@ -394,6 +393,12 @@ public class EnsembleTests extends AbstractSerializingTestCase { featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + + featureMap = new HashMap(2) {{ + put("foo", 0.3); + put("bar", null); + }}; + assertEquals(1.8, ensemble.infer(featureMap), 0.00001); } private static Map zipObjMap(List keys, List values) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 1063041467e..b98d19b07ff 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -17,12 +17,14 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; @@ -118,19 +120,26 @@ public class TreeTests extends AbstractSerializingTestCase { // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.3, tree.infer(featureMap), 0.00001); + assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.1, tree.infer(featureMap), 0.00001); + assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); // This should hit the right child of the left child of the root node // i.e. it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.2, tree.infer(featureMap), 0.00001); + assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001)); + + // This should handle missing values and take the default_left path + featureMap = new HashMap(2) {{ + put("foo", 0.3); + put("bar", null); + }}; + assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); } public void testTreeClassificationProbability() { @@ -162,6 +171,13 @@ public class TreeTests extends AbstractSerializingTestCase { featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap)); + + // This should handle missing values and take the default_left path + featureMap = new HashMap(2) {{ + put("foo", 0.3); + put("bar", null); + }}; + assertEquals(1.0, tree.infer(featureMap), 0.00001); } public void testTreeWithNullRoot() {