The tree trained model object will set its target_type to be regression by default. This updates the inference object to behave the same way.
This commit is contained in:
parent
c6acc7c976
commit
3309817d18
|
@ -7,6 +7,7 @@
|
||||||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
||||||
|
|
||||||
import org.apache.lucene.util.Accountable;
|
import org.apache.lucene.util.Accountable;
|
||||||
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.Numbers;
|
import org.elasticsearch.common.Numbers;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
@ -60,12 +61,16 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
|
||||||
"tree_inference_model",
|
"tree_inference_model",
|
||||||
true,
|
true,
|
||||||
a -> new TreeInferenceModel((List<String>)a[0], (List<NodeBuilder>)a[1], TargetType.fromString((String)a[2]), (List<String>)a[3]));
|
a -> new TreeInferenceModel(
|
||||||
|
(List<String>)a[0],
|
||||||
|
(List<NodeBuilder>)a[1],
|
||||||
|
a[2] == null ? null : TargetType.fromString((String)a[2]),
|
||||||
|
(List<String>)a[3]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
|
PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
|
||||||
PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE);
|
PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE);
|
||||||
PARSER.declareString(constructorArg(), TARGET_TYPE);
|
PARSER.declareString(optionalConstructorArg(), TARGET_TYPE);
|
||||||
PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
|
PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,13 +87,16 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
private final int leafSize;
|
private final int leafSize;
|
||||||
private volatile boolean preparedForInference = false;
|
private volatile boolean preparedForInference = false;
|
||||||
|
|
||||||
TreeInferenceModel(List<String> featureNames, List<NodeBuilder> nodes, TargetType targetType, List<String> classificationLabels) {
|
TreeInferenceModel(List<String> featureNames,
|
||||||
|
List<NodeBuilder> nodes,
|
||||||
|
@Nullable TargetType targetType,
|
||||||
|
List<String> classificationLabels) {
|
||||||
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
|
this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
|
||||||
if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
|
if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
|
||||||
throw new IllegalArgumentException("[tree_structure] must not be empty");
|
throw new IllegalArgumentException("[tree_structure] must not be empty");
|
||||||
}
|
}
|
||||||
this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
|
this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
|
||||||
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
|
this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
|
||||||
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
|
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
|
||||||
this.highOrderCategory = maxLeafValue();
|
this.highOrderCategory = maxLeafValue();
|
||||||
int leafSize = 1;
|
int leafSize = 1;
|
||||||
|
@ -357,7 +365,7 @@ public class TreeInferenceModel implements InferenceModel {
|
||||||
return Math.max(depthLeft, depthRight) + 1;
|
return Math.max(depthLeft, depthRight) + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class NodeBuilder {
|
static class NodeBuilder {
|
||||||
|
|
||||||
private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
|
private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
|
||||||
"tree_inference_model_node",
|
"tree_inference_model_node",
|
||||||
|
|
|
@ -58,6 +58,15 @@ public class TreeInferenceModelTests extends ESTestCase {
|
||||||
return new NamedXContentRegistry(namedXContent);
|
return new NamedXContentRegistry(namedXContent);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testCtorWithNullTargetType() {
|
||||||
|
TreeInferenceModel treeInferenceModel = new TreeInferenceModel(
|
||||||
|
Collections.emptyList(),
|
||||||
|
Collections.singletonList(new TreeInferenceModel.NodeBuilder().setLeafValue(new double[]{1.0}).setNumberSamples(100L)),
|
||||||
|
null,
|
||||||
|
Collections.emptyList());
|
||||||
|
assertThat(treeInferenceModel.targetType(), equalTo(TargetType.REGRESSION));
|
||||||
|
}
|
||||||
|
|
||||||
public void testSerializationFromEnsemble() throws Exception {
|
public void testSerializationFromEnsemble() throws Exception {
|
||||||
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
|
for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) {
|
||||||
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));
|
Tree tree = TreeTests.createRandom(randomFrom(TargetType.values()));
|
||||||
|
|
Loading…
Reference in New Issue