The tree trained model object will set its target_type to be regression by default. This updates the inference object to behave the same way.
This commit is contained in:
parent
c6acc7c976
commit
3309817d18
|
@ -7,6 +7,7 @@
|
|||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
||||
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.Numbers;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
|
@ -60,12 +61,16 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
||||
"tree_inference_model",
|
||||
true,
|
||||
a -> new TreeInferenceModel((List<String>)a[0], (List<NodeBuilder>)a[1], TargetType.fromString((String)a[2]), (List<String>)a[3]));
|
||||
a -> new TreeInferenceModel(
|
||||
(List<String>)a[0],
|
||||
(List<NodeBuilder>)a[1],
|
||||
a[2] == null ? null : TargetType.fromString((String)a[2]),
|
||||
(List<String>)a[3]));
|
||||
|
||||
static {
|
||||
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
|
||||
PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE);
|
||||
PARSER.declareString(constructorArg(), TARGET_TYPE);
|
||||
PARSER.declareString(optionalConstructorArg(), TARGET_TYPE);
|
||||
PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
|
||||
}
|
||||
|
||||
|
@ -82,13 +87,16 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
private final int leafSize;
|
||||
private volatile boolean preparedForInference = false;
|
||||
|
||||
TreeInferenceModel(List<String> featureNames, List<NodeBuilder> nodes, TargetType targetType, List<String> classificationLabels) {
|
||||
TreeInferenceModel(List<String> featureNames,
|
||||
List<NodeBuilder> nodes,
|
||||
@Nullable TargetType targetType,
|
||||
List<String> classificationLabels) {
|
||||
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
|
||||
if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
|
||||
throw new IllegalArgumentException("[tree_structure] must not be empty");
|
||||
}
|
||||
this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
|
||||
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
|
||||
this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
|
||||
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
|
||||
this.highOrderCategory = maxLeafValue();
|
||||
int leafSize = 1;
|
||||
|
@ -357,7 +365,7 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
return Math.max(depthLeft, depthRight) + 1;
|
||||
}
|
||||
|
||||
private static class NodeBuilder {
|
||||
static class NodeBuilder {
|
||||
|
||||
private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
|
||||
"tree_inference_model_node",
|
||||
|
|
|
@ -58,6 +58,15 @@ public class TreeInferenceModelTests extends ESTestCase {
|
|||
return new NamedXContentRegistry(namedXContent);
|
||||
}
|
||||
|
||||
public void testCtorWithNullTargetType() {
|
||||
TreeInferenceModel treeInferenceModel = new TreeInferenceModel(
|
||||
Collections.emptyList(),
|
||||
Collections.singletonList(new TreeInferenceModel.NodeBuilder().setLeafValue(new double[]{1.0}).setNumberSamples(100L)),
|
||||
null,
|
||||
Collections.emptyList());
|
||||
assertThat(treeInferenceModel.targetType(), equalTo(TargetType.REGRESSION));
|
||||
}
|
||||
|
||||
public void testSerializationFromEnsemble() throws Exception {
|
||||
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
|
||||
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));
|
||||
|
|
Loading…
Reference in New Issue