diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index e43fd4a9b56..72256b35e7f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; @@ -114,19 +115,9 @@ public class EnsembleInferenceModel implements InferenceModel { return targetType; } - private double[] getFeatures(Map fields) { - double[] features = new double[featureNames.length]; - int i = 0; - for (String featureName : featureNames) { - Double val = InferenceHelpers.toDouble(fields.get(featureName)); - features[i++] = val == null ? Double.NaN : val; - } - return features; - } - @Override public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) { - return innerInfer(getFeatures(fields), config, featureDecoderMap); + return innerInfer(InferenceModel.extractFeatures(featureNames, fields), config, featureDecoderMap); } @Override @@ -142,9 +133,9 @@ public class EnsembleInferenceModel implements InferenceModel { if (preparedForInference == false) { throw ExceptionsHelper.serverError("model is not prepared for inference"); } - LOGGER.debug("Inference called with feature names [{}]", - featureNames == null ? "" : Strings.arrayToCommaDelimitedString(featureNames)); - assert featureNames != null && featureNames.length > 0; + LOGGER.debug( + () -> new ParameterizedMessage("Inference called with feature names [{}]", Strings.arrayToCommaDelimitedString(featureNames)) + ); double[][] inferenceResults = new double[this.models.size()][]; double[][] featureInfluence = new double[features.length][]; int i = 0; @@ -244,27 +235,28 @@ public class EnsembleInferenceModel implements InferenceModel { } @Override - public void rewriteFeatureIndices(Map newFeatureIndexMapping) { - LOGGER.debug("rewriting features {}", newFeatureIndexMapping); + public void rewriteFeatureIndices(final Map newFeatureIndexMapping) { + LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping)); if (preparedForInference) { return; } preparedForInference = true; + Map featureIndexMapping = new HashMap<>(); if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) { Set referencedFeatures = subModelFeatures(); - LOGGER.debug("detected submodel feature names {}", referencedFeatures); + LOGGER.debug(() -> new ParameterizedMessage("detected submodel feature names {}", referencedFeatures)); int newFeatureIndex = 0; - newFeatureIndexMapping = new HashMap<>(); + featureIndexMapping = new HashMap<>(); this.featureNames = new String[referencedFeatures.size()]; for (String featureName : referencedFeatures) { - newFeatureIndexMapping.put(featureName, newFeatureIndex); + featureIndexMapping.put(featureName, newFeatureIndex); this.featureNames[newFeatureIndex++] = featureName; } } else { this.featureNames = new String[0]; } for (InferenceModel model : models) { - model.rewriteFeatureIndices(newFeatureIndexMapping); + model.rewriteFeatureIndices(featureIndexMapping); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index 0e95b66dd15..d4d18b6849d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -102,6 +102,15 @@ public class InferenceDefinition { } } + @Override + public String toString() { + return "InferenceDefinition{" + + "trainedModel=" + trainedModel + + ", preProcessors=" + preProcessors + + ", decoderMap=" + decoderMap + + '}'; + } + public static Builder builder() { return new Builder(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java index 56bf9e30263..8df81e4868a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceModel.java @@ -10,12 +10,23 @@ import org.apache.lucene.util.Accountable; import org.elasticsearch.common.Nullable; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.util.Map; public interface InferenceModel extends Accountable { + static double[] extractFeatures(String[] featureNames, Map fields) { + double[] features = new double[featureNames.length]; + int i = 0; + for (String featureName : featureNames) { + Double val = InferenceHelpers.toDouble(fields.get(featureName)); + features[i++] = val == null ? Double.NaN : val; + } + return features; + } + /** * @return The feature names in their desired order */ diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index ac35f0cff4a..c4649503a49 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.Accountable; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Numbers; @@ -127,7 +128,7 @@ public class TreeInferenceModel implements InferenceModel { @Override public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) { - return innerInfer(getFeatures(fields), config, featureDecoderMap); + return innerInfer(InferenceModel.extractFeatures(featureNames, fields), config, featureDecoderMap); } @Override @@ -135,16 +136,6 @@ public class TreeInferenceModel implements InferenceModel { return innerInfer(features, config, Collections.emptyMap()); } - private double[] getFeatures(Map fields) { - double[] features = new double[featureNames.length]; - int i = 0; - for (String featureName : featureNames) { - Double val = InferenceHelpers.toDouble(fields.get(featureName)); - features[i++] = val == null ? Double.NaN : val; - } - return features; - } - private InferenceResults innerInfer(double[] features, InferenceConfig config, Map featureDecoderMap) { if (config.isTargetTypeSupported(targetType) == false) { throw ExceptionsHelper.badRequestException( @@ -311,7 +302,7 @@ public class TreeInferenceModel implements InferenceModel { @Override public void rewriteFeatureIndices(Map newFeatureIndexMapping) { - LOGGER.debug("rewriting features {}", newFeatureIndexMapping); + LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping)); if (preparedForInference) { return; } @@ -353,7 +344,7 @@ public class TreeInferenceModel implements InferenceModel { if (node instanceof LeafNode) { LeafNode leafNode = (LeafNode) node; if (leafNode.leafValue.length > 1) { - return (double)leafNode.leafValue.length; + return leafNode.leafValue.length; } else { max = Math.max(leafNode.leafValue[0], max); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 1689778a9ed..11a80ef00e1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -165,6 +165,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM throw ExceptionsHelper.badRequestException("feature index [{}] is out of bounds for the [{}] array", maxFeatureIndex, FEATURE_NAMES.getPreferredName()); } + if (nodes.size() > 1) { + if (featureNames.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] is empty and the tree has > 1 nodes; num nodes [{}]. " + + "The model Must have features if tree is not a stump", + FEATURE_NAMES.getPreferredName(), nodes.size()); + } + } checkTargetType(); detectMissingNodes(); detectCycle(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 82d3aabdeab..fec079aa539 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -175,6 +175,35 @@ public class TreeTests extends AbstractSerializingTestCase { assertThat(ex.getMessage(), equalTo(msg)); } + public void testTreeWithEmptyFeaturesAndOneNode() { + // Shouldn't throw + Tree.builder() + .setRoot(TreeNode.builder(0).setLeafValue(10.0)) + .setFeatureNames(Collections.emptyList()) + .build() + .validate(); + } + + public void testTreeWithEmptyFeaturesAndThreeNodes() { + String msg = "[feature_names] is empty and the tree has > 1 nodes; num nodes [3]. " + + "The model Must have features if tree is not a stump"; + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Tree.builder() + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setThreshold(randomDouble())) + .addNode(TreeNode.builder(1) + .setLeafValue(randomDouble())) + .addNode(TreeNode.builder(2) + .setLeafValue(randomDouble())) + .setFeatureNames(Collections.emptyList()) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo(msg)); + } + public void testOperationsEstimations() { Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); assertThat(tree.estimatedNumOperations(), equalTo(7L));