[ML] adding feature_name and node size validation for tree models (#62096) (#62161)

When a tree model is provided, it is possible that it is a stump,
meaning it has only one node and no splits.
Such a tree references no features, so having zero feature_names
is appropriate. In any other case, an empty feature_names list
should be considered a validation failure.

This commit adds validation that, when a tree has more than one
node, the feature_names in the model are non-empty.
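
A hedged sketch of what the new check enforces, using the same Tree and
TreeNode builder calls as the unit tests below (literal values stand in
for the randomized ones):

    // A stump: one leaf node and no splits, so empty feature_names is still valid.
    Tree.builder()
        .setRoot(TreeNode.builder(0).setLeafValue(10.0))
        .setFeatureNames(Collections.emptyList())
        .build()
        .validate();

    // A three-node tree with empty feature_names now fails validation.
    Tree.builder()
        .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setThreshold(0.5))
        .addNode(TreeNode.builder(1).setLeafValue(1.0))
        .addNode(TreeNode.builder(2).setLeafValue(2.0))
        .setFeatureNames(Collections.emptyList())
        .build()
        .validate(); // throws an ElasticsearchException (bad request)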

closes #60759
Benjamin Trent 2020-09-09 08:50:25 -04:00 committed by GitHub
parent 6710104673
commit 1b9dc0172a
6 changed files with 72 additions and 33 deletions

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
@ -114,19 +115,9 @@ public class EnsembleInferenceModel implements InferenceModel {
return targetType;
}
private double[] getFeatures(Map<String, Object> fields) {
double[] features = new double[featureNames.length];
int i = 0;
for (String featureName : featureNames) {
Double val = InferenceHelpers.toDouble(fields.get(featureName));
features[i++] = val == null ? Double.NaN : val;
}
return features;
}
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
return innerInfer(getFeatures(fields), config, featureDecoderMap);
return innerInfer(InferenceModel.extractFeatures(featureNames, fields), config, featureDecoderMap);
}
@Override
@ -142,9 +133,9 @@ public class EnsembleInferenceModel implements InferenceModel {
if (preparedForInference == false) {
throw ExceptionsHelper.serverError("model is not prepared for inference");
}
LOGGER.debug("Inference called with feature names [{}]",
featureNames == null ? "<null>" : Strings.arrayToCommaDelimitedString(featureNames));
assert featureNames != null && featureNames.length > 0;
LOGGER.debug(
() -> new ParameterizedMessage("Inference called with feature names [{}]", Strings.arrayToCommaDelimitedString(featureNames))
);
double[][] inferenceResults = new double[this.models.size()][];
double[][] featureInfluence = new double[features.length][];
int i = 0;
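
Passing a supplier to LOGGER.debug defers building the ParameterizedMessage
until debug logging is actually enabled, so the hot inference path skips the
string work. A minimal, self-contained sketch of the pattern (the class and
method names here are illustrative, not from this commit):

    import org.apache.logging.log4j.LogManager;
    import org.apache.logging.log4j.Logger;
    import org.apache.logging.log4j.message.ParameterizedMessage;

    class LazyDebugSketch {
        private static final Logger LOGGER = LogManager.getLogger(LazyDebugSketch.class);

        void infer(String[] featureNames) {
            // The lambda runs only when the debug level is enabled, so String.join
            // and the message object are not allocated on the normal code path.
            LOGGER.debug(() -> new ParameterizedMessage(
                "Inference called with feature names [{}]", String.join(",", featureNames)));
        }
    }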
@ -244,27 +235,28 @@ public class EnsembleInferenceModel implements InferenceModel {
}
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
public void rewriteFeatureIndices(final Map<String, Integer> newFeatureIndexMapping) {
LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping));
if (preparedForInference) {
return;
}
preparedForInference = true;
Map<String, Integer> featureIndexMapping = new HashMap<>();
if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
Set<String> referencedFeatures = subModelFeatures();
LOGGER.debug("detected submodel feature names {}", referencedFeatures);
LOGGER.debug(() -> new ParameterizedMessage("detected submodel feature names {}", referencedFeatures));
int newFeatureIndex = 0;
newFeatureIndexMapping = new HashMap<>();
featureIndexMapping = new HashMap<>();
this.featureNames = new String[referencedFeatures.size()];
for (String featureName : referencedFeatures) {
newFeatureIndexMapping.put(featureName, newFeatureIndex);
featureIndexMapping.put(featureName, newFeatureIndex);
this.featureNames[newFeatureIndex++] = featureName;
}
} else {
this.featureNames = new String[0];
}
for (InferenceModel model : models) {
model.rewriteFeatureIndices(newFeatureIndexMapping);
model.rewriteFeatureIndices(featureIndexMapping);
}
}

View File

@ -102,6 +102,15 @@ public class InferenceDefinition {
}
}
@Override
public String toString() {
return "InferenceDefinition{" +
"trainedModel=" + trainedModel +
", preProcessors=" + preProcessors +
", decoderMap=" + decoderMap +
'}';
}
public static Builder builder() {
return new Builder();
}

View File

@ -10,12 +10,23 @@ import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.util.Map;
public interface InferenceModel extends Accountable {
static double[] extractFeatures(String[] featureNames, Map<String, Object> fields) {
double[] features = new double[featureNames.length];
int i = 0;
for (String featureName : featureNames) {
Double val = InferenceHelpers.toDouble(fields.get(featureName));
features[i++] = val == null ? Double.NaN : val;
}
return features;
}
/**
* @return The feature names in their desired order
*/
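
This helper centralizes what EnsembleInferenceModel and TreeInferenceModel
previously duplicated: one slot per feature name, in order, with Double.NaN for
anything missing. A standalone sketch of that behavior (the real method delegates
the numeric conversion to InferenceHelpers.toDouble, so this is a simplification):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;

    public class ExtractFeaturesSketch {
        // Mirrors InferenceModel.extractFeatures: one slot per feature name, in order,
        // with Double.NaN for fields that are absent (or non-numeric, in this simplification).
        static double[] extractFeatures(String[] featureNames, Map<String, Object> fields) {
            double[] features = new double[featureNames.length];
            int i = 0;
            for (String featureName : featureNames) {
                Object val = fields.get(featureName);
                features[i++] = (val instanceof Number) ? ((Number) val).doubleValue() : Double.NaN;
            }
            return features;
        }

        public static void main(String[] args) {
            Map<String, Object> fields = new HashMap<>();
            fields.put("f1", 1.5);
            fields.put("f3", 2);
            // "f2" is absent, so its slot becomes NaN.
            System.out.println(Arrays.toString(
                extractFeatures(new String[] { "f1", "f2", "f3" }, fields))); // [1.5, NaN, 2.0]
        }
    }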

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Numbers;
@ -127,7 +128,7 @@ public class TreeInferenceModel implements InferenceModel {
@Override
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
return innerInfer(getFeatures(fields), config, featureDecoderMap);
return innerInfer(InferenceModel.extractFeatures(featureNames, fields), config, featureDecoderMap);
}
@Override
@ -135,16 +136,6 @@ public class TreeInferenceModel implements InferenceModel {
return innerInfer(features, config, Collections.emptyMap());
}
private double[] getFeatures(Map<String, Object> fields) {
double[] features = new double[featureNames.length];
int i = 0;
for (String featureName : featureNames) {
Double val = InferenceHelpers.toDouble(fields.get(featureName));
features[i++] = val == null ? Double.NaN : val;
}
return features;
}
private InferenceResults innerInfer(double[] features, InferenceConfig config, Map<String, String> featureDecoderMap) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
@ -311,7 +302,7 @@ public class TreeInferenceModel implements InferenceModel {
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping));
if (preparedForInference) {
return;
}
@ -353,7 +344,7 @@ public class TreeInferenceModel implements InferenceModel {
if (node instanceof LeafNode) {
LeafNode leafNode = (LeafNode) node;
if (leafNode.leafValue.length > 1) {
return (double)leafNode.leafValue.length;
return leafNode.leafValue.length;
} else {
max = Math.max(leafNode.leafValue[0], max);
}

View File

@ -165,6 +165,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
maxFeatureIndex, FEATURE_NAMES.getPreferredName());
}
if (nodes.size() > 1) {
if (featureNames.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] is empty and the tree has > 1 nodes; num nodes [{}]. " +
"The model Must have features if tree is not a stump",
FEATURE_NAMES.getPreferredName(), nodes.size());
}
}
checkTargetType();
detectMissingNodes();
detectCycle();

View File

@ -175,6 +175,35 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(ex.getMessage(), equalTo(msg));
}
public void testTreeWithEmptyFeaturesAndOneNode() {
// Shouldn't throw
Tree.builder()
.setRoot(TreeNode.builder(0).setLeafValue(10.0))
.setFeatureNames(Collections.emptyList())
.build()
.validate();
}
public void testTreeWithEmptyFeaturesAndThreeNodes() {
String msg = "[feature_names] is empty and the tree has > 1 nodes; num nodes [3]. " +
"The model Must have features if tree is not a stump";
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Tree.builder()
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setThreshold(randomDouble()))
.addNode(TreeNode.builder(1)
.setLeafValue(randomDouble()))
.addNode(TreeNode.builder(2)
.setLeafValue(randomDouble()))
.setFeatureNames(Collections.emptyList())
.build()
.validate();
});
assertThat(ex.getMessage(), equalTo(msg));
}
public void testOperationsEstimations() {
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
assertThat(tree.estimatedNumOperations(), equalTo(7L));