[ML][Inference] Adds validations for model PUT (#51376) (#51409)

Adds validations to the trained model PUT request, making sure that:

* `input.field_names` is not empty
* `ensemble.trained_models` is not empty
* `tree.feature_names` is not empty

closes https://github.com/elastic/elasticsearch/issues/51354
This commit is contained in:
Benjamin Trent 2020-01-24 09:29:12 -05:00 committed by GitHub
parent d177747f66
commit fc994d9ce1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 133 additions and 1 deletions

View File

@ -32,6 +32,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
@ -70,7 +71,7 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors)
.collect(Collectors.toList()))
.setTrainedModel(randomFrom(TreeTests.buildRandomTree(Collections.emptyList(), 6, targetType),
.setTrainedModel(randomFrom(TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 6, targetType),
EnsembleTests.createRandom(targetType)));
}

View File

@ -535,6 +535,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
break;
}
}
if (input != null && input.getFieldNames().isEmpty()) {
validationException = addValidationError("[input.field_names] must not be empty", validationException);
}
if (forCreation) {
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);

View File

@ -250,6 +250,10 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
@Override
public void validate() {
if (this.models.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must not be empty", TRAINED_MODELS.getPreferredName());
}
if (outputAggregator.compatibleWith(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"aggregate_output [{}] is not compatible with target_type [{}]",

View File

@ -253,6 +253,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
@Override
public void validate() {
if (featureNames.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
}
checkTargetType();
detectMissingNodes();
detectCycle();

View File

@ -26,6 +26,7 @@ import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -202,6 +203,14 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
assertThat(ex.getMessage(), equalTo(msg));
}
/**
 * Verifies that validating an ensemble built with an empty list of trained models
 * fails with the "[trained_models] must not be empty" error.
 */
public void testEnsembleWithEmptyModels() {
    final String expectedMessage = "[trained_models] must not be empty";
    final List<String> names = Arrays.asList("foo", "bar");
    // The whole builder chain runs inside the lambda, exactly as the validation call does.
    ElasticsearchException failure = expectThrows(
        ElasticsearchException.class,
        () -> Ensemble.builder()
            .setTrainedModels(Collections.emptyList())
            .setFeatureNames(names)
            .build()
            .validate());
    assertThat(failure.getMessage(), equalTo(expectedMessage));
}
public void testClassificationProbability() {
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree1 = Tree.builder()

View File

@ -339,6 +339,21 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(ex.getMessage(), equalTo(msg));
}
/**
 * Verifies that validating a tree built with an empty feature-name list fails with
 * the "[feature_names] must not be empty for tree model" error.
 */
public void testTreeWithEmptyFeatureNames() {
    // Root node and tree are assembled inside the lambda, so the random threshold is drawn there too.
    ElasticsearchException failure = expectThrows(
        ElasticsearchException.class,
        () -> Tree.builder()
            .setRoot(TreeNode.builder(0).setLeftChild(1).setSplitFeature(1).setThreshold(randomDouble()))
            .setFeatureNames(Collections.emptyList())
            .build()
            .validate());
    assertThat(failure.getMessage(), equalTo("[feature_names] must not be empty for tree model"));
}
public void testOperationsEstimations() {
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
assertThat(tree.estimatedNumOperations(), equalTo(7L));

View File

@ -136,6 +136,9 @@ integTest.runner {
'ml/inference_crud/Test delete with missing model',
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
'ml/inference_crud/Test put ensemble with empty models',
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
'ml/inference_crud/Test put model with empty input.field_names',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
'ml/jobs_crud/Test cannot create job with existing categorizer state document',

View File

@ -206,3 +206,97 @@ setup:
allow_no_match: false
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
---
# PUTs an ensemble whose "trained_models" array is empty and expects the request
# to fail with an error matching "[trained_models] must not be empty".
"Test put ensemble with empty models":
- do:
catch: /\[trained_models\] must not be empty/
ml.put_trained_model:
model_id: "missing_model_ensembles"
body: >
{
"input": {
"field_names": "fieldy_mc_fieldname"
},
"definition": {
"trained_model": {
"ensemble": {
"feature_names": [],
"trained_models": []
}
}
}
}
---
# PUTs an ensemble containing a single tree whose "feature_names" array is empty and
# expects the request to fail with an error matching "[feature_names] must not be empty".
# NOTE(review): the ensemble-level "feature_names" is empty as well; the matched message
# is presumably the one raised by the tree validation — confirm.
"Test put ensemble with tree where tree has empty feature-names":
- do:
catch: /\[feature_names\] must not be empty/
ml.put_trained_model:
model_id: "ensemble_tree_missing_feature_names"
body: >
{
"input": {
"field_names": "fieldy_mc_fieldname"
},
"definition": {
"trained_model": {
"ensemble": {
"feature_names": [],
"trained_models": [
{
"tree": {
"feature_names": [],
"tree_structure": [
{
"node_index": 0,
"split_feature": 0,
"split_gain": 12.0,
"threshold": 10.0,
"decision_type": "lte",
"default_left": true,
"left_child": 1,
"right_child": 2
}]
}
}
]
}
}
}
}
---
# PUTs a model whose "input.field_names" array is empty and expects the request to fail
# with an error matching "[input.field_names] must not be empty".
# NOTE(review): the nested tree also has an empty "feature_names" array, so this test
# relies on the input.field_names error being the one reported — confirm the ordering.
# NOTE(review): model_id is the same as in "Test put ensemble with empty models";
# looks like a copy-paste, verify it is intentional.
"Test put model with empty input.field_names":
- do:
catch: /\[input\.field_names\] must not be empty/
ml.put_trained_model:
model_id: "missing_model_ensembles"
body: >
{
"input": {
"field_names": []
},
"definition": {
"trained_model": {
"ensemble": {
"feature_names": [],
"trained_models": [
{
"tree": {
"feature_names": [],
"tree_structure": [
{
"node_index": 0,
"split_feature": 0,
"split_gain": 12.0,
"threshold": 10.0,
"decision_type": "lte",
"default_left": true,
"left_child": 1,
"right_child": 2
}]
}
}
]
}
}
}
}