[ML] fixing tree inference ctor to allow target_type to be optional (#58132) (#58165)

The tree trained model object will set its target_type to be regression by default.

This updates the inference object to behave the same way.
This commit is contained in:
Benjamin Trent 2020-06-16 13:29:11 -04:00 committed by GitHub
parent c6acc7c976
commit 3309817d18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 5 deletions

View File

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Numbers; import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -60,12 +61,16 @@ public class TreeInferenceModel implements InferenceModel {
private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>( private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
"tree_inference_model", "tree_inference_model",
true, true,
a -> new TreeInferenceModel((List<String>)a[0], (List<NodeBuilder>)a[1], TargetType.fromString((String)a[2]), (List<String>)a[3])); a -> new TreeInferenceModel(
(List<String>)a[0],
(List<NodeBuilder>)a[1],
a[2] == null ? null : TargetType.fromString((String)a[2]),
(List<String>)a[3]));
static { static {
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES); PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE); PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE);
PARSER.declareString(constructorArg(), TARGET_TYPE); PARSER.declareString(optionalConstructorArg(), TARGET_TYPE);
PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS); PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
} }
@ -82,13 +87,16 @@ public class TreeInferenceModel implements InferenceModel {
private final int leafSize; private final int leafSize;
private volatile boolean preparedForInference = false; private volatile boolean preparedForInference = false;
TreeInferenceModel(List<String> featureNames, List<NodeBuilder> nodes, TargetType targetType, List<String> classificationLabels) { TreeInferenceModel(List<String> featureNames,
List<NodeBuilder> nodes,
@Nullable TargetType targetType,
List<String> classificationLabels) {
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]); this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) { if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
throw new IllegalArgumentException("[tree_structure] must not be empty"); throw new IllegalArgumentException("[tree_structure] must not be empty");
} }
this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new); this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highOrderCategory = maxLeafValue(); this.highOrderCategory = maxLeafValue();
int leafSize = 1; int leafSize = 1;
@ -357,7 +365,7 @@ public class TreeInferenceModel implements InferenceModel {
return Math.max(depthLeft, depthRight) + 1; return Math.max(depthLeft, depthRight) + 1;
} }
private static class NodeBuilder { static class NodeBuilder {
private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>( private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
"tree_inference_model_node", "tree_inference_model_node",

View File

@ -58,6 +58,15 @@ public class TreeInferenceModelTests extends ESTestCase {
return new NamedXContentRegistry(namedXContent); return new NamedXContentRegistry(namedXContent);
} }
public void testCtorWithNullTargetType() {
TreeInferenceModel treeInferenceModel = new TreeInferenceModel(
Collections.emptyList(),
Collections.singletonList(new TreeInferenceModel.NodeBuilder().setLeafValue(new double[]{1.0}).setNumberSamples(100L)),
null,
Collections.emptyList());
assertThat(treeInferenceModel.targetType(), equalTo(TargetType.REGRESSION));
}
public void testSerializationFromEnsemble() throws Exception { public void testSerializationFromEnsemble() throws Exception {
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) { for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values())); Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));