mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-23 13:26:02 +00:00
[Ml] Validate tree feature index is within range (#52514)
This changes the tree validation code to ensure no node in the tree has a feature index that is beyond the bounds of the feature_names array. Specifically, this handles the situation where the C++ code emits a tree containing a single node and an empty feature_names list. This is a valid tree used to centre the data in the ensemble, but the validation code would reject it because feature_names is empty. This meant a broken workflow, as you could not GET the model and PUT it back.
This commit is contained in:
parent
43376c6e06
commit
7bbe5c8464
@ -61,11 +61,11 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
|
public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
|
||||||
int numFeatures = featureNames.size();
|
int maxFeatureIndex = featureNames.size() -1;
|
||||||
Tree.Builder builder = Tree.builder();
|
Tree.Builder builder = Tree.builder();
|
||||||
builder.setFeatureNames(featureNames);
|
builder.setFeatureNames(featureNames);
|
||||||
|
|
||||||
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
|
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
|
||||||
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
||||||
|
|
||||||
for (int i = 0; i < depth -1; i++) {
|
for (int i = 0; i < depth -1; i++) {
|
||||||
@ -76,7 +76,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
|
|||||||
builder.addLeaf(nodeId, randomDouble());
|
builder.addLeaf(nodeId, randomDouble());
|
||||||
} else {
|
} else {
|
||||||
TreeNode.Builder childNode =
|
TreeNode.Builder childNode =
|
||||||
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
|
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
|
||||||
nextNodes.add(childNode.getLeftChild());
|
nextNodes.add(childNode.getLeftChild());
|
||||||
nextNodes.add(childNode.getRightChild());
|
nextNodes.add(childNode.getRightChild());
|
||||||
}
|
}
|
||||||
|
@ -253,8 +253,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void validate() {
|
public void validate() {
|
||||||
if (featureNames.isEmpty()) {
|
int maxFeatureIndex = maxFeatureIndex();
|
||||||
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
|
if (maxFeatureIndex >= featureNames.size()) {
|
||||||
|
throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
|
||||||
|
maxFeatureIndex, FEATURE_NAMES.getPreferredName());
|
||||||
}
|
}
|
||||||
checkTargetType();
|
checkTargetType();
|
||||||
detectMissingNodes();
|
detectMissingNodes();
|
||||||
@ -267,6 +269,23 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|||||||
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
|
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The highest index of a feature used by any of the nodes.
|
||||||
|
* If no nodes use a feature return -1. This can only happen
|
||||||
|
* if the tree contains a single leaf node.
|
||||||
|
*
|
||||||
|
* @return The max or -1
|
||||||
|
*/
|
||||||
|
int maxFeatureIndex() {
|
||||||
|
int maxFeatureIndex = -1;
|
||||||
|
|
||||||
|
for (TreeNode node : nodes) {
|
||||||
|
maxFeatureIndex = Math.max(maxFeatureIndex, node.getSplitFeature());
|
||||||
|
}
|
||||||
|
|
||||||
|
return maxFeatureIndex;
|
||||||
|
}
|
||||||
|
|
||||||
private void checkTargetType() {
|
private void checkTargetType() {
|
||||||
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
|
if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) {
|
||||||
throw ExceptionsHelper.badRequestException(
|
throw ExceptionsHelper.badRequestException(
|
||||||
|
@ -29,6 +29,7 @@ import java.util.stream.Collectors;
|
|||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.closeTo;
|
import static org.hamcrest.Matchers.closeTo;
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
|
|
||||||
@ -72,10 +73,10 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
|
|||||||
|
|
||||||
public static Tree buildRandomTree(List<String> featureNames, int depth) {
|
public static Tree buildRandomTree(List<String> featureNames, int depth) {
|
||||||
Tree.Builder builder = Tree.builder();
|
Tree.Builder builder = Tree.builder();
|
||||||
int numFeatures = featureNames.size() - 1;
|
int maxFeatureIndex = featureNames.size() - 1;
|
||||||
builder.setFeatureNames(featureNames);
|
builder.setFeatureNames(featureNames);
|
||||||
|
|
||||||
TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
|
TreeNode.Builder node = builder.addJunction(0, randomInt(maxFeatureIndex), true, randomDouble());
|
||||||
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
||||||
|
|
||||||
for (int i = 0; i < depth -1; i++) {
|
for (int i = 0; i < depth -1; i++) {
|
||||||
@ -86,7 +87,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
|
|||||||
builder.addLeaf(nodeId, randomDouble());
|
builder.addLeaf(nodeId, randomDouble());
|
||||||
} else {
|
} else {
|
||||||
TreeNode.Builder childNode =
|
TreeNode.Builder childNode =
|
||||||
builder.addJunction(nodeId, randomInt(numFeatures), true, randomDouble());
|
builder.addJunction(nodeId, randomInt(maxFeatureIndex), true, randomDouble());
|
||||||
nextNodes.add(childNode.getLeftChild());
|
nextNodes.add(childNode.getLeftChild());
|
||||||
nextNodes.add(childNode.getRightChild());
|
nextNodes.add(childNode.getRightChild());
|
||||||
}
|
}
|
||||||
@ -339,26 +340,83 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
|
|||||||
assertThat(ex.getMessage(), equalTo(msg));
|
assertThat(ex.getMessage(), equalTo(msg));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testTreeWithEmptyFeatureNames() {
|
|
||||||
String msg = "[feature_names] must not be empty for tree model";
|
|
||||||
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
|
|
||||||
Tree.builder()
|
|
||||||
.setRoot(TreeNode.builder(0)
|
|
||||||
.setLeftChild(1)
|
|
||||||
.setSplitFeature(1)
|
|
||||||
.setThreshold(randomDouble()))
|
|
||||||
.setFeatureNames(Collections.emptyList())
|
|
||||||
.build()
|
|
||||||
.validate();
|
|
||||||
});
|
|
||||||
assertThat(ex.getMessage(), equalTo(msg));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testOperationsEstimations() {
|
public void testOperationsEstimations() {
|
||||||
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
|
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
|
||||||
assertThat(tree.estimatedNumOperations(), equalTo(7L));
|
assertThat(tree.estimatedNumOperations(), equalTo(7L));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testMaxFeatureIndex() {
|
||||||
|
|
||||||
|
int numFeatures = randomIntBetween(1, 15);
|
||||||
|
// We need a tree where every feature is used, choose a depth big enough to
|
||||||
|
// accommodate those non-leaf nodes (leaf nodes don't have a feature index)
|
||||||
|
int depth = (int) Math.ceil(Math.log(numFeatures +1) / Math.log(2)) + 1;
|
||||||
|
List<String> featureNames = new ArrayList<>(numFeatures);
|
||||||
|
for (int i=0; i<numFeatures; i++) {
|
||||||
|
featureNames.add("feature" + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
|
||||||
|
|
||||||
|
// build a tree using feature indices 0..numFeatures -1
|
||||||
|
int featureIndex = 0;
|
||||||
|
TreeNode.Builder node = builder.addJunction(0, featureIndex++, true, randomDouble());
|
||||||
|
List<Integer> childNodes = Arrays.asList(node.getLeftChild(), node.getRightChild());
|
||||||
|
|
||||||
|
for (int i = 0; i < depth -1; i++) {
|
||||||
|
List<Integer> nextNodes = new ArrayList<>();
|
||||||
|
for (int nodeId : childNodes) {
|
||||||
|
if (i == depth -2) {
|
||||||
|
builder.addLeaf(nodeId, randomDouble());
|
||||||
|
} else {
|
||||||
|
TreeNode.Builder childNode =
|
||||||
|
builder.addJunction(nodeId, featureIndex++ % numFeatures, true, randomDouble());
|
||||||
|
nextNodes.add(childNode.getLeftChild());
|
||||||
|
nextNodes.add(childNode.getRightChild());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
childNodes = nextNodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tree tree = builder.build();
|
||||||
|
|
||||||
|
assertEquals(numFeatures, tree.maxFeatureIndex() +1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMaxFeatureIndexSingleNodeTree() {
|
||||||
|
Tree tree = Tree.builder()
|
||||||
|
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
|
||||||
|
.setFeatureNames(Collections.emptyList())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assertEquals(-1, tree.maxFeatureIndex());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidateGivenMissingFeatures() {
|
||||||
|
List<String> featureNames = Arrays.asList("foo", "bar", "baz");
|
||||||
|
|
||||||
|
// build a tree referencing a feature at index 3 which is not in the featureNames list
|
||||||
|
Tree.Builder builder = Tree.builder().setFeatureNames(featureNames);
|
||||||
|
builder.addJunction(0, 0, true, randomDouble());
|
||||||
|
builder.addJunction(1, 1, true, randomDouble());
|
||||||
|
builder.addJunction(2, 3, true, randomDouble());
|
||||||
|
builder.addLeaf(3, randomDouble());
|
||||||
|
builder.addLeaf(4, randomDouble());
|
||||||
|
builder.addLeaf(5, randomDouble());
|
||||||
|
builder.addLeaf(6, randomDouble());
|
||||||
|
|
||||||
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> builder.build().validate());
|
||||||
|
assertThat(e.getDetailedMessage(), containsString("feature index [3] is out of bounds for the [feature_names] array"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidateGivenTreeWithNoFeatures() {
|
||||||
|
Tree.builder()
|
||||||
|
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
|
||||||
|
.setFeatureNames(Collections.emptyList())
|
||||||
|
.build()
|
||||||
|
.validate();
|
||||||
|
}
|
||||||
|
|
||||||
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
|
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
|
||||||
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
|
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
|
||||||
}
|
}
|
||||||
|
@ -137,7 +137,7 @@ integTest.runner {
|
|||||||
'ml/inference_crud/Test get given missing trained model',
|
'ml/inference_crud/Test get given missing trained model',
|
||||||
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
|
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
|
||||||
'ml/inference_crud/Test put ensemble with empty models',
|
'ml/inference_crud/Test put ensemble with empty models',
|
||||||
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
|
'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
|
||||||
'ml/inference_crud/Test put model with empty input.field_names',
|
'ml/inference_crud/Test put model with empty input.field_names',
|
||||||
'ml/inference_stats_crud/Test get stats given missing trained model',
|
'ml/inference_stats_crud/Test get stats given missing trained model',
|
||||||
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
|
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
|
||||||
|
@ -333,6 +333,40 @@ setup:
|
|||||||
- match: { count: 1 }
|
- match: { count: 1 }
|
||||||
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
|
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
|
||||||
---
|
---
|
||||||
|
"Test put ensemble with single node and empty feature_names":
|
||||||
|
|
||||||
|
- do:
|
||||||
|
ml.put_trained_model:
|
||||||
|
model_id: "ensemble_tree_empty_feature_names"
|
||||||
|
body: >
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"field_names": "fieldy_mc_fieldname"
|
||||||
|
},
|
||||||
|
"definition": {
|
||||||
|
"trained_model": {
|
||||||
|
"ensemble": {
|
||||||
|
"feature_names": [],
|
||||||
|
"trained_models": [
|
||||||
|
{
|
||||||
|
"tree": {
|
||||||
|
"feature_names": [],
|
||||||
|
"tree_structure": [
|
||||||
|
{
|
||||||
|
"node_index": 0,
|
||||||
|
"decision_type": "lte",
|
||||||
|
"leaf_value": 12.0,
|
||||||
|
"default_left": true
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
---
|
||||||
"Test put ensemble with empty models":
|
"Test put ensemble with empty models":
|
||||||
- do:
|
- do:
|
||||||
catch: /\[trained_models\] must not be empty/
|
catch: /\[trained_models\] must not be empty/
|
||||||
@ -353,11 +387,11 @@ setup:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
---
|
---
|
||||||
"Test put ensemble with tree where tree has empty feature-names":
|
"Test put ensemble with tree where tree has out of bounds feature_names index":
|
||||||
- do:
|
- do:
|
||||||
catch: /\[feature_names\] must not be empty/
|
catch: /feature index \[1\] is out of bounds for the \[feature_names\] array/
|
||||||
ml.put_trained_model:
|
ml.put_trained_model:
|
||||||
model_id: "ensemble_tree_missing_feature_names"
|
model_id: "ensemble_tree_out_of_bounds_feature_names_index"
|
||||||
body: >
|
body: >
|
||||||
{
|
{
|
||||||
"input": {
|
"input": {
|
||||||
@ -374,7 +408,7 @@ setup:
|
|||||||
"tree_structure": [
|
"tree_structure": [
|
||||||
{
|
{
|
||||||
"node_index": 0,
|
"node_index": 0,
|
||||||
"split_feature": 0,
|
"split_feature": 1,
|
||||||
"split_gain": 12.0,
|
"split_gain": 12.0,
|
||||||
"threshold": 10.0,
|
"threshold": 10.0,
|
||||||
"decision_type": "lte",
|
"decision_type": "lte",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user