From 3309817d186a91e92e1d55daf1a96cd585498729 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 16 Jun 2020 13:29:11 -0400 Subject: [PATCH] [ML] fixing tree inference ctor to allow target_type to be optional (#58132) (#58165) The tree trained model object will set its target_type to be regression by default. This updates the inference object to behave the same way. --- .../inference/TreeInferenceModel.java | 18 +++++++++++++----- .../inference/TreeInferenceModelTests.java | 9 +++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index f9a0a81700f..7cdf480e056 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; import org.apache.lucene.util.Accountable; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Numbers; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -60,12 +61,16 @@ public class TreeInferenceModel implements InferenceModel { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "tree_inference_model", true, - a -> new TreeInferenceModel((List)a[0], (List)a[1], TargetType.fromString((String)a[2]), (List)a[3])); + a -> new TreeInferenceModel( + (List)a[0], + (List)a[1], + a[2] == null ? null : TargetType.fromString((String)a[2]), + (List)a[3])); static { PARSER.declareStringArray(constructorArg(), FEATURE_NAMES); PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE); - PARSER.declareString(constructorArg(), TARGET_TYPE); + PARSER.declareString(optionalConstructorArg(), TARGET_TYPE); PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS); } @@ -82,13 +87,16 @@ public class TreeInferenceModel implements InferenceModel { private final int leafSize; private volatile boolean preparedForInference = false; - TreeInferenceModel(List featureNames, List nodes, TargetType targetType, List classificationLabels) { + TreeInferenceModel(List featureNames, + List nodes, + @Nullable TargetType targetType, + List classificationLabels) { this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]); if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) { throw new IllegalArgumentException("[tree_structure] must not be empty"); } this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new); - this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); + this.targetType = targetType == null ? TargetType.REGRESSION : targetType; this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); this.highOrderCategory = maxLeafValue(); int leafSize = 1; @@ -357,7 +365,7 @@ public class TreeInferenceModel implements InferenceModel { return Math.max(depthLeft, depthRight) + 1; } - private static class NodeBuilder { + static class NodeBuilder { private static final ObjectParser PARSER = new ObjectParser<>( "tree_inference_model_node", diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java index e0d40848376..995bba94024 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java @@ -58,6 +58,15 @@ public class TreeInferenceModelTests extends ESTestCase { return new NamedXContentRegistry(namedXContent); } + public void testCtorWithNullTargetType() { + TreeInferenceModel treeInferenceModel = new TreeInferenceModel( + Collections.emptyList(), + Collections.singletonList(new TreeInferenceModel.NodeBuilder().setLeafValue(new double[]{1.0}).setNumberSamples(100L)), + null, + Collections.emptyList()); + assertThat(treeInferenceModel.targetType(), equalTo(TargetType.REGRESSION)); + } + public void testSerializationFromEnsemble() throws Exception { for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) { Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));