mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-22 12:56:53 +00:00
[Ml] Validate tree feature index is within range (#52514)
This changes the tree validation code to ensure that no node in the tree has a feature index beyond the bounds of the feature_names array. Specifically, this handles the situation where the C++ code emits a tree containing a single node and an empty feature_names list. This is a valid tree, used to centre the data in the ensemble, but the validation code would reject it because feature_names is empty. This broke the workflow, as you could not GET the model and PUT it back.
This commit is contained in:
parent
43376c6e06
commit
7bbe5c8464
@ -61,11 +61,11 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
|
||||
}
|
||||
|
||||
public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
|
||||
int numFeatures = featureNames.size();
|
||||
int maxFeatureIndex = featureNames.size() -1;
|
||||
Tree.Builder builder = Tree.builder();
|
||||
builder.setFeatureNames(featureNames);
|
||||
|
||||
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
|
||||
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
|
||||
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
||||
|
||||
for (int i = 0; i < depth -1; i++) {
|
||||
@ -76,7 +76,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
|
||||
builder.addLeaf(nodeId, randomDouble());
|
||||
} else {
|
||||
TreeNode.Builder childNode =
|
||||
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
|
||||
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
|
||||
nextNodes.add(childNode.getLeftChild());
|
||||
nextNodes.add(childNode.getRightChild());
|
||||
}
|
||||
|
@ -253,8 +253,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
||||
|
||||
@Override
|
||||
public void validate() {
|
||||
if (featureNames.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
|
||||
int maxFeatureIndex = maxFeatureIndex();
|
||||
if (maxFeatureIndex >= featureNames.size()) {
|
||||
throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
|
||||
maxFeatureIndex, FEATURE_NAMES.getPreferredName());
|
||||
}
|
||||
checkTargetType();
|
||||
detectMissingNodes();
|
||||
@ -267,6 +269,23 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
||||
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* The highest index of a feature used by any of the nodes.
|
||||
* If no nodes use a feature return -1. This can only happen
|
||||
* if the tree contains a single leaf node.
|
||||
*
|
||||
* @return The max or -1
|
||||
*/
|
||||
int maxFeatureIndex() {
|
||||
int maxFeatureIndex = -1;
|
||||
|
||||
for (TreeNode node : nodes) {
|
||||
maxFeatureIndex = Math.max(maxFeatureIndex, node.getSplitFeature());
|
||||
}
|
||||
|
||||
return maxFeatureIndex;
|
||||
}
|
||||
|
||||
private void checkTargetType() {
|
||||
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
|
@ -29,6 +29,7 @@ import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
|
||||
@ -72,10 +73,10 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
|
||||
|
||||
public static Tree buildRandomTree(List<String> featureNames, int depth) {
|
||||
Tree.Builder builder = Tree.builder();
|
||||
int numFeatures = featureNames.size() - 1;
|
||||
int maxFeatureIndex = featureNames.size() - 1;
|
||||
builder.setFeatureNames(featureNames);
|
||||
|
||||
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
|
||||
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
|
||||
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
||||
|
||||
for (int i = 0; i < depth -1; i++) {
|
||||
@ -86,7 +87,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
|
||||
builder.addLeaf(nodeId, randomDouble());
|
||||
} else {
|
||||
TreeNode.Builder childNode =
|
||||
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
|
||||
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
|
||||
nextNodes.add(childNode.getLeftChild());
|
||||
nextNodes.add(childNode.getRightChild());
|
||||
}
|
||||
@ -339,26 +340,83 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
|
||||
assertThat(ex.getMessage(), equalTo(msg));
|
||||
}
|
||||
|
||||
public void testTreeWithEmptyFeatureNames() {
|
||||
String msg = "[feature_names] must not be empty for tree model";
|
||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
|
||||
Tree.builder()
|
||||
.setRoot(TreeNode.builder(0)
|
||||
.setLeftChild(1)
|
||||
.setSplitFeature(1)
|
||||
.setThreshold(randomDouble()))
|
||||
.setFeatureNames(Collections.emptyList())
|
||||
.build()
|
||||
.validate();
|
||||
});
|
||||
assertThat(ex.getMessage(), equalTo(msg));
|
||||
}
|
||||
|
||||
public void testOperationsEstimations() {
|
||||
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
|
||||
assertThat(tree.estimatedNumOperations(), equalTo(7L));
|
||||
}
|
||||
|
||||
public void testMaxFeatureIndex() {
|
||||
|
||||
int numFeatures = randomIntBetween(1, 15);
|
||||
// We need a tree where every feature is used, choose a depth big enough to
|
||||
// accommodate those non-leaf nodes (leaf nodes don't have a feature index)
|
||||
int depth = (int) Math.ceil(Math.log(numFeatures +1) / Math.log(2)) + 1;
|
||||
List<String> featureNames = new ArrayList<>(numFeatures);
|
||||
for (int i=0; i<numFeatures; i++) {
|
||||
featureNames.add("feature" + i);
|
||||
}
|
||||
|
||||
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
|
||||
|
||||
// build a tree using feature indices 0..numFeatures -1
|
||||
int featureIndex = 0;
|
||||
TreeNode.Builder node = builder.addJunction(0, featureIndex++, true, randomDouble());
|
||||
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
||||
|
||||
for (int i = 0; i < depth -1; i++) {
|
||||
List<Integer> nextNodes = new ArrayList<>();
|
||||
for (int nodeId : childNodes) {
|
||||
if (i == depth -2) {
|
||||
builder.addLeaf(nodeId, randomDouble());
|
||||
} else {
|
||||
TreeNode.Builder childNode =
|
||||
builder.addJunction(nodeId, featureIndex++ % numFeatures, true, randomDouble());
|
||||
nextNodes.add(childNode.getLeftChild());
|
||||
nextNodes.add(childNode.getRightChild());
|
||||
}
|
||||
}
|
||||
childNodes = nextNodes;
|
||||
}
|
||||
|
||||
Tree tree = builder.build();
|
||||
|
||||
assertEquals(numFeatures, tree.maxFeatureIndex() +1);
|
||||
}
|
||||
|
||||
public void testMaxFeatureIndexSingleNodeTree() {
|
||||
Tree tree = Tree.builder()
|
||||
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
|
||||
.setFeatureNames(Collections.emptyList())
|
||||
.build();
|
||||
|
||||
assertEquals(-1, tree.maxFeatureIndex());
|
||||
}
|
||||
|
||||
public void testValidateGivenMissingFeatures() {
|
||||
List<String> featureNames = Arrays.asList("foo", "bar", "baz");
|
||||
|
||||
// build a tree referencing a feature at index 3 which is not in the featureNames list
|
||||
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
|
||||
builder.addJunction(0, 0, true, randomDouble());
|
||||
builder.addJunction(1, 1, true, randomDouble());
|
||||
builder.addJunction(2, 3, true, randomDouble());
|
||||
builder.addLeaf(3, randomDouble());
|
||||
builder.addLeaf(4, randomDouble());
|
||||
builder.addLeaf(5, randomDouble());
|
||||
builder.addLeaf(6, randomDouble());
|
||||
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> builder.build().validate());
|
||||
assertThat(e.getDetailedMessage(), containsString("feature index [3] is out of bounds for the [feature_names] array"));
|
||||
}
|
||||
|
||||
public void testValidateGivenTreeWithNoFeatures() {
|
||||
Tree.builder()
|
||||
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
|
||||
.setFeatureNames(Collections.emptyList())
|
||||
.build()
|
||||
.validate();
|
||||
}
|
||||
|
||||
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
|
||||
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
|
||||
}
|
||||
|
@ -137,7 +137,7 @@ integTest.runner {
|
||||
'ml/inference_crud/Test get given missing trained model',
|
||||
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
|
||||
'ml/inference_crud/Test put ensemble with empty models',
|
||||
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
|
||||
'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
|
||||
'ml/inference_crud/Test put model with empty input.field_names',
|
||||
'ml/inference_stats_crud/Test get stats given missing trained model',
|
||||
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
|
||||
|
@ -333,6 +333,40 @@ setup:
|
||||
- match: { count: 1 }
|
||||
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
|
||||
---
|
||||
"Test put ensemble with single node and empty feature_names":
|
||||
|
||||
- do:
|
||||
ml.put_trained_model:
|
||||
model_id: "ensemble_tree_empty_feature_names"
|
||||
body: >
|
||||
{
|
||||
"input": {
|
||||
"field_names": "fieldy_mc_fieldname"
|
||||
},
|
||||
"definition": {
|
||||
"trained_model": {
|
||||
"ensemble": {
|
||||
"feature_names": [],
|
||||
"trained_models": [
|
||||
{
|
||||
"tree": {
|
||||
"feature_names": [],
|
||||
"tree_structure": [
|
||||
{
|
||||
"node_index": 0,
|
||||
"decision_type": "lte",
|
||||
"leaf_value": 12.0,
|
||||
"default_left": true
|
||||
}]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put ensemble with empty models":
|
||||
- do:
|
||||
catch: /\[trained_models\] must not be empty/
|
||||
@ -353,11 +387,11 @@ setup:
|
||||
}
|
||||
}
|
||||
---
|
||||
"Test put ensemble with tree where tree has empty feature-names":
|
||||
"Test put ensemble with tree where tree has out of bounds feature_names index":
|
||||
- do:
|
||||
catch: /\[feature_names\] must not be empty/
|
||||
catch: /feature index \[1\] is out of bounds for the \[feature_names\] array/
|
||||
ml.put_trained_model:
|
||||
model_id: "ensemble_tree_missing_feature_names"
|
||||
model_id: "ensemble_tree_out_of_bounds_feature_names_index"
|
||||
body: >
|
||||
{
|
||||
"input": {
|
||||
@ -374,7 +408,7 @@ setup:
|
||||
"tree_structure": [
|
||||
{
|
||||
"node_index": 0,
|
||||
"split_feature": 0,
|
||||
"split_feature": 1,
|
||||
"split_gain": 12.0,
|
||||
"threshold": 10.0,
|
||||
"decision_type": "lte",
|
||||
|
Loading…
x
Reference in New Issue
Block a user