When a tree model is provided, it is possible that it is a stump, meaning it has only one node with no splits. This implies that the tree has no features, so having zero feature_names is appropriate. In any other case, an empty feature_names should be considered a validation failure. This commit adds validation that, if the tree has more than one node, the feature_names in the model are non-empty. closes #60759
This commit is contained in:
parent 6710104673
commit 1b9dc0172a
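For illustration, here is a minimal standalone sketch of the rule this commit enforces, using hypothetical names (the actual check lives in Tree#validate(), shown in the diff below):

import java.util.List;

class TreeFeatureValidationSketch {
    // Hypothetical illustration of the rule described above,
    // not the actual Elasticsearch code.
    static void validateFeatureNames(List<String> featureNames, int numNodes) {
        // A stump is a single leaf node with no splits, so it references
        // no features and an empty feature_names list is legitimate.
        if (numNodes <= 1) {
            return;
        }
        // A tree with more than one node must declare the features it splits on.
        if (featureNames.isEmpty()) {
            throw new IllegalArgumentException(
                "feature_names is empty but the tree has " + numNodes + " nodes");
        }
    }
}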
@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
@@ -114,19 +115,9 @@ public class EnsembleInferenceModel implements InferenceModel {
         return targetType;
     }
 
-    private double[] getFeatures(Map<String, Object> fields) {
-        double[] features = new double[featureNames.length];
-        int i = 0;
-        for (String featureName : featureNames) {
-            Double val = InferenceHelpers.toDouble(fields.get(featureName));
-            features[i++] = val == null ? Double.NaN : val;
-        }
-        return features;
-    }
-
     @Override
     public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
-        return innerInfer(getFeatures(fields), config, featureDecoderMap);
+        return innerInfer(InferenceModel.extractFeatures(featureNames, fields), config, featureDecoderMap);
     }
 
     @Override
@@ -142,9 +133,9 @@ public class EnsembleInferenceModel implements InferenceModel {
         if (preparedForInference == false) {
             throw ExceptionsHelper.serverError("model is not prepared for inference");
         }
-        LOGGER.debug("Inference called with feature names [{}]",
-            featureNames == null ? "<null>" : Strings.arrayToCommaDelimitedString(featureNames));
+        assert featureNames != null && featureNames.length > 0;
+        LOGGER.debug(
+            () -> new ParameterizedMessage("Inference called with feature names [{}]", Strings.arrayToCommaDelimitedString(featureNames))
+        );
         double[][] inferenceResults = new double[this.models.size()][];
         double[][] featureInfluence = new double[features.length][];
         int i = 0;
@@ -244,27 +235,28 @@ public class EnsembleInferenceModel implements InferenceModel {
     }
 
     @Override
-    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
-        LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
+    public void rewriteFeatureIndices(final Map<String, Integer> newFeatureIndexMapping) {
+        LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping));
         if (preparedForInference) {
             return;
         }
         preparedForInference = true;
+        Map<String, Integer> featureIndexMapping = new HashMap<>();
         if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
             Set<String> referencedFeatures = subModelFeatures();
-            LOGGER.debug("detected submodel feature names {}", referencedFeatures);
+            LOGGER.debug(() -> new ParameterizedMessage("detected submodel feature names {}", referencedFeatures));
             int newFeatureIndex = 0;
-            newFeatureIndexMapping = new HashMap<>();
+            featureIndexMapping = new HashMap<>();
             this.featureNames = new String[referencedFeatures.size()];
             for (String featureName : referencedFeatures) {
-                newFeatureIndexMapping.put(featureName, newFeatureIndex);
+                featureIndexMapping.put(featureName, newFeatureIndex);
                 this.featureNames[newFeatureIndex++] = featureName;
             }
         } else {
             this.featureNames = new String[0];
         }
         for (InferenceModel model : models) {
-            model.rewriteFeatureIndices(newFeatureIndexMapping);
+            model.rewriteFeatureIndices(featureIndexMapping);
         }
     }
 
@@ -102,6 +102,15 @@ public class InferenceDefinition {
         }
     }
 
+    @Override
+    public String toString() {
+        return "InferenceDefinition{" +
+            "trainedModel=" + trainedModel +
+            ", preProcessors=" + preProcessors +
+            ", decoderMap=" + decoderMap +
+            '}';
+    }
+
     public static Builder builder() {
         return new Builder();
     }
@@ -10,12 +10,23 @@ import org.apache.lucene.util.Accountable;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 
 import java.util.Map;
 
 public interface InferenceModel extends Accountable {
 
+    static double[] extractFeatures(String[] featureNames, Map<String, Object> fields) {
+        double[] features = new double[featureNames.length];
+        int i = 0;
+        for (String featureName : featureNames) {
+            Double val = InferenceHelpers.toDouble(fields.get(featureName));
+            features[i++] = val == null ? Double.NaN : val;
+        }
+        return features;
+    }
+
     /**
     * @return The feature names in their desired order
     */
@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.Accountable;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Numbers;
@@ -127,7 +128,7 @@ public class TreeInferenceModel implements InferenceModel {
 
     @Override
     public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
-        return innerInfer(getFeatures(fields), config, featureDecoderMap);
+        return innerInfer(InferenceModel.extractFeatures(featureNames, fields), config, featureDecoderMap);
     }
 
     @Override
@@ -135,16 +136,6 @@ public class TreeInferenceModel implements InferenceModel {
         return innerInfer(features, config, Collections.emptyMap());
     }
 
-    private double[] getFeatures(Map<String, Object> fields) {
-        double[] features = new double[featureNames.length];
-        int i = 0;
-        for (String featureName : featureNames) {
-            Double val = InferenceHelpers.toDouble(fields.get(featureName));
-            features[i++] = val == null ? Double.NaN : val;
-        }
-        return features;
-    }
-
     private InferenceResults innerInfer(double[] features, InferenceConfig config, Map<String, String> featureDecoderMap) {
         if (config.isTargetTypeSupported(targetType) == false) {
             throw ExceptionsHelper.badRequestException(
@@ -311,7 +302,7 @@ public class TreeInferenceModel implements InferenceModel {
 
     @Override
     public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
-        LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
+        LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping));
         if (preparedForInference) {
             return;
         }
@@ -353,7 +344,7 @@ public class TreeInferenceModel implements InferenceModel {
             if (node instanceof LeafNode) {
                 LeafNode leafNode = (LeafNode) node;
                 if (leafNode.leafValue.length > 1) {
-                    return (double)leafNode.leafValue.length;
+                    return leafNode.leafValue.length;
                 } else {
                     max = Math.max(leafNode.leafValue[0], max);
                 }
@@ -165,6 +165,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array",
                 maxFeatureIndex, FEATURE_NAMES.getPreferredName());
         }
+        if (nodes.size() > 1) {
+            if (featureNames.isEmpty()) {
+                throw ExceptionsHelper.badRequestException("[{}] is empty and the tree has > 1 nodes; num nodes [{}]. " +
+                    "The model Must have features if tree is not a stump",
+                    FEATURE_NAMES.getPreferredName(), nodes.size());
+            }
+        }
         checkTargetType();
         detectMissingNodes();
         detectCycle();
@@ -175,6 +175,35 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         assertThat(ex.getMessage(), equalTo(msg));
     }
 
+    public void testTreeWithEmptyFeaturesAndOneNode() {
+        // Shouldn't throw
+        Tree.builder()
+            .setRoot(TreeNode.builder(0).setLeafValue(10.0))
+            .setFeatureNames(Collections.emptyList())
+            .build()
+            .validate();
+    }
+
+    public void testTreeWithEmptyFeaturesAndThreeNodes() {
+        String msg = "[feature_names] is empty and the tree has > 1 nodes; num nodes [3]. " +
+            "The model Must have features if tree is not a stump";
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
+            Tree.builder()
+                .setRoot(TreeNode.builder(0)
+                    .setLeftChild(1)
+                    .setRightChild(2)
+                    .setThreshold(randomDouble()))
+                .addNode(TreeNode.builder(1)
+                    .setLeafValue(randomDouble()))
+                .addNode(TreeNode.builder(2)
+                    .setLeafValue(randomDouble()))
+                .setFeatureNames(Collections.emptyList())
+                .build()
+                .validate();
+        });
+        assertThat(ex.getMessage(), equalTo(msg));
+    }
+
     public void testOperationsEstimations() {
         Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
         assertThat(tree.estimatedNumOperations(), equalTo(7L));