[Ml] Validate tree feature index is within range (#52514)

This changes the tree validation code to ensure no node in the tree has a
feature index that is beyond the bounds of the feature_names array.
Specifically, this handles the situation where the C++ code emits a tree containing
a single node and an empty feature_names list. This is a valid tree used to
centre the data in the ensemble, but the validation code would reject it
because feature_names is empty. This meant a broken workflow, as you could not GET
the model and PUT it back.
This commit is contained in:
David Kyle 2020-02-19 14:41:43 +00:00 committed by GitHub
parent 43376c6e06
commit 7bbe5c8464
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 139 additions and 28 deletions

View File

@ -61,11 +61,11 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
}
public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
int numFeatures = featureNames.size();
int maxFeatureIndex = featureNames.size() -1;
Tree.Builder builder = Tree.builder();
builder.setFeatureNames(featureNames);
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
for (int i = 0; i < depth -1; i++) {
@ -76,7 +76,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}

View File

@ -253,8 +253,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
@Override
public void validate() {
if (featureNames.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
int maxFeatureIndex = maxFeatureIndex();
if (maxFeatureIndex >= featureNames.size()) {
throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
maxFeatureIndex, FEATURE_NAMES.getPreferredName());
}
checkTargetType();
detectMissingNodes();
@ -267,6 +269,23 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
}
/**
 * Finds the largest split feature index reported by any node of this tree.
 * The result is -1 when no node reports an index of zero or more, which
 * can only happen if the tree consists of a single leaf node.
 *
 * @return the largest split feature index, or -1 if there is none
 */
int maxFeatureIndex() {
    int highest = -1;
    for (TreeNode treeNode : nodes) {
        int splitFeature = treeNode.getSplitFeature();
        if (splitFeature > highest) {
            highest = splitFeature;
        }
    }
    return highest;
}
private void checkTargetType() {
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
throw ExceptionsHelper.badRequestException(

View File

@ -29,6 +29,7 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@ -72,10 +73,10 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
public static Tree buildRandomTree(List<String> featureNames, int depth) {
Tree.Builder builder = Tree.builder();
int numFeatures = featureNames.size() - 1;
int maxFeatureIndex = featureNames.size() - 1;
builder.setFeatureNames(featureNames);
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
for (int i = 0; i < depth -1; i++) {
@ -86,7 +87,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
builder.addLeaf(nodeId, randomDouble());
} else {
TreeNode.Builder childNode =
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
nextNodes.add(childNode.getLeftChild());
nextNodes.add(childNode.getRightChild());
}
@ -339,26 +340,83 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(ex.getMessage(), equalTo(msg));
}
/**
 * Builds a tree whose root splits on feature index 1 while the feature names
 * list is empty, and expects validation to fail with the "must not be empty" message.
 */
// NOTE(review): the validate() hunk in this listing also shows a newer
// "out of bounds" message; confirm which version of validate() this test targets.
public void testTreeWithEmptyFeatureNames() {
    ElasticsearchException thrown = expectThrows(ElasticsearchException.class, () -> {
        Tree invalidTree = Tree.builder()
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setSplitFeature(1)
                .setThreshold(randomDouble()))
            .setFeatureNames(Collections.emptyList())
            .build();
        invalidTree.validate();
    });
    assertThat(thrown.getMessage(), equalTo("[feature_names] must not be empty for tree model"));
}
/**
 * Checks the operation estimate reported for a random tree of depth 5 over three features.
 */
public void testOperationsEstimations() {
    List<String> featureNames = Arrays.asList("foo", "bar", "baz");
    Tree randomTree = buildRandomTree(featureNames, 5);
    // estimatedNumOperations() is ceil(ln(node count)) + number of feature names;
    // presumably a depth 5 tree has 31 nodes, giving 4 + 3 - verify against buildRandomTree
    long estimate = randomTree.estimatedNumOperations();
    assertThat(estimate, equalTo(7L));
}
/**
 * Builds a full binary tree whose junction nodes cycle through every feature
 * index and checks that maxFeatureIndex() reports the last index.
 */
public void testMaxFeatureIndex() {
    int numFeatures = randomIntBetween(1, 15);
    // Every feature must be used by some junction, so pick a depth with enough
    // non-leaf nodes to hold them all (leaf nodes don't have a feature index)
    int depth = (int) Math.ceil(Math.log(numFeatures +1) / Math.log(2)) + 1;

    List<String> featureNames = new ArrayList<>(numFeatures);
    for (int f = 0; f < numFeatures; f++) {
        featureNames.add("feature" + f);
    }

    Tree.Builder treeBuilder = Tree.builder().setFeatureNames(featureNames);

    // junctions take feature indices 0..numFeatures-1, wrapping around when exhausted
    int nextFeature = 0;
    TreeNode.Builder root = treeBuilder.addJunction(0, nextFeature, true, randomDouble());
    nextFeature++;
    List<Integer> frontier = Arrays.asList(root.getLeftChild(), root.getRightChild());

    // all levels except the last are junctions
    for (int level = 0; level < depth - 2; level++) {
        List<Integer> nextFrontier = new ArrayList<>();
        for (int nodeId : frontier) {
            TreeNode.Builder junction =
                treeBuilder.addJunction(nodeId, nextFeature % numFeatures, true, randomDouble());
            nextFeature++;
            nextFrontier.add(junction.getLeftChild());
            nextFrontier.add(junction.getRightChild());
        }
        frontier = nextFrontier;
    }

    // the final level consists of leaves only
    for (int nodeId : frontier) {
        treeBuilder.addLeaf(nodeId, randomDouble());
    }

    Tree tree = treeBuilder.build();
    assertEquals(numFeatures, tree.maxFeatureIndex() +1);
}
/**
 * A tree made of a single leaf and no feature names is expected to report -1
 * as its max feature index.
 */
public void testMaxFeatureIndexSingleNodeTree() {
    Tree.Builder treeBuilder = Tree.builder();
    treeBuilder.setRoot(TreeNode.builder(0).setLeafValue(10.0));
    treeBuilder.setFeatureNames(Collections.emptyList());
    Tree singleLeaf = treeBuilder.build();
    assertEquals(-1, singleLeaf.maxFeatureIndex());
}
/**
 * Validation is expected to reject a tree in which a junction splits on
 * feature index 3 while only three feature names (indices 0..2) are supplied.
 */
public void testValidateGivenMissingFeatures() {
    Tree.Builder treeBuilder = Tree.builder().setFeatureNames(Arrays.asList("foo", "bar", "baz"));

    // three junctions; the last one references index 3, which has no feature name
    treeBuilder.addJunction(0, 0, true, randomDouble());
    treeBuilder.addJunction(1, 1, true, randomDouble());
    treeBuilder.addJunction(2, 3, true, randomDouble());

    // nodes 3..6 are the leaves
    for (int leafId = 3; leafId <= 6; leafId++) {
        treeBuilder.addLeaf(leafId, randomDouble());
    }

    ElasticsearchStatusException failure = expectThrows(ElasticsearchStatusException.class, () -> {
        Tree tree = treeBuilder.build();
        tree.validate();
    });
    assertThat(failure.getDetailedMessage(),
        containsString("feature index [3] is out of bounds for the [feature_names] array"));
}
/**
 * Validates a single-leaf tree with an empty feature names list; the test
 * passes as long as validate() does not throw.
 */
public void testValidateGivenTreeWithNoFeatures() {
    Tree.Builder treeBuilder = Tree.builder();
    treeBuilder.setRoot(TreeNode.builder(0).setLeafValue(10.0));
    treeBuilder.setFeatureNames(Collections.emptyList());
    Tree singleLeaf = treeBuilder.build();
    singleLeaf.validate();
}
/**
 * Pairs each key with the value at the same position and collects the pairs into a map.
 * Positions are driven by the size of {@code keys}.
 *
 * @param keys   the map keys, one per position
 * @param values the values, looked up by the same position as the key
 * @return a map from each key to the value at the matching position
 */
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
    return IntStream.range(0, keys.size())
        .boxed()
        .collect(Collectors.toMap(position -> keys.get(position), position -> values.get(position)));
}

View File

@ -137,7 +137,7 @@ integTest.runner {
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
'ml/inference_crud/Test put ensemble with empty models',
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
'ml/inference_crud/Test put model with empty input.field_names',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',

View File

@ -333,6 +333,40 @@ setup:
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
---
"Test put ensemble with single node and empty feature_names":
- do:
ml.put_trained_model:
model_id: "ensemble_tree_empty_feature_names"
body: >
{
"input": {
"field_names": "fieldy_mc_fieldname"
},
"definition": {
"trained_model": {
"ensemble": {
"feature_names": [],
"trained_models": [
{
"tree": {
"feature_names": [],
"tree_structure": [
{
"node_index": 0,
"decision_type": "lte",
"leaf_value": 12.0,
"default_left": true
}]
}
}
]
}
}
}
}
---
"Test put ensemble with empty models":
- do:
catch: /\[trained_models\] must not be empty/
@ -353,11 +387,11 @@ setup:
}
}
---
"Test put ensemble with tree where tree has empty feature-names":
"Test put ensemble with tree where tree has out of bounds feature_names index":
- do:
catch: /\[feature_names\] must not be empty/
catch: /feature index \[1\] is out of bounds for the \[feature_names\] array/
ml.put_trained_model:
model_id: "ensemble_tree_missing_feature_names"
model_id: "ensemble_tree_out_of_bounds_feature_names_index"
body: >
{
"input": {
@ -374,7 +408,7 @@ setup:
"tree_structure": [
{
"node_index": 0,
"split_feature": 0,
"split_feature": 1,
"split_gain": 12.0,
"threshold": 10.0,
"decision_type": "lte",