[ML] adds multi-class feature importance support (#53803) (#54024)

Adds multi-class feature importance calculation. 

Feature importance objects are now mapped as follows
(logistic) Regression:
```
{
   "feature_name": "feature_0",
   "importance": -1.3
}
```
Multi-class [class names are `foo`, `bar`, `baz`]
```
{ 
   “feature_name”: “feature_0”, 
   “importance”: 2.0, // sum(abs()) of class importances
   “foo”: 1.0, 
   “bar”: 0.5, 
   “baz”: -0.5 
},
```

For users to get the full benefit of aggregating and searching for feature importance, they should update their index mapping as follows (before turning this option on in their pipelines)
```
 "ml.inference.feature_importance": {
          "type": "nested",
          "dynamic": true,
          "properties": {
            "feature_name": {
              "type": "keyword"
            },
            "importance": {
              "type": "double"
            }
          }
        }
```
The mapping field name is as follows
`ml.<inference.target_field>.<inference.tag>.feature_importance`
if `inference.tag` is not provided in the processor definition, it is not part of the field path.
`inference.target_field` is defaulted to `ml.inference`.
//cc @lcawl ^ Where should we document this?

If this makes it in for 7.7, there shouldn't be any feature_importance at inference BWC worries as 7.7 is the first version to have it.
This commit is contained in:
Benjamin Trent 2020-03-23 18:49:07 -04:00 committed by GitHub
parent 5c96a7e210
commit 19af869243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 410 additions and 138 deletions

View File

@ -35,13 +35,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
String classificationLabel,
List<TopClassEntry> topClasses,
InferenceConfig config) {
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
}
public ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
List<FeatureImportance> featureImportance,
InferenceConfig config) {
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
}
@ -49,7 +49,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
private ClassificationInferenceResults(double value,
String classificationLabel,
List<TopClassEntry> topClasses,
Map<String, Double> featureImportance,
List<FeatureImportance> featureImportance,
ClassificationConfig classificationConfig) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
@ -118,7 +118,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
public class FeatureImportance implements Writeable {
private final Map<String, Double> classImportance;
private final double importance;
private final String featureName;
private static final String IMPORTANCE = "importance";
private static final String FEATURE_NAME = "feature_name";
public static FeatureImportance forRegression(String featureName, double importance) {
return new FeatureImportance(featureName, importance, null);
}
public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
}
private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
this.featureName = Objects.requireNonNull(featureName);
this.importance = importance;
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
}
public FeatureImportance(StreamInput in) throws IOException {
this.featureName = in.readString();
this.importance = in.readDouble();
if (in.readBoolean()) {
this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
} else {
this.classImportance = null;
}
}
public Map<String, Double> getClassImportance() {
return classImportance;
}
public double getImportance() {
return importance;
}
public String getFeatureName() {
return featureName;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.featureName);
out.writeDouble(this.importance);
out.writeBoolean(this.classImportance != null);
if (this.classImportance != null) {
out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
}
}
public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(FEATURE_NAME, featureName);
map.put(IMPORTANCE, importance);
if (classImportance != null) {
classImportance.forEach(map::put);
}
return map;
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
FeatureImportance that = (FeatureImportance) object;
return Objects.equals(featureName, that.featureName)
&& Objects.equals(importance, that.importance)
&& Objects.equals(classImportance, that.classImportance);
}
@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportance);
}
}

View File

@ -18,9 +18,9 @@ public class RawInferenceResults implements InferenceResults {
public static final String NAME = "raw";
private final double[] value;
private final Map<String, Double> featureImportance;
private final Map<String, double[]> featureImportance;
public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
this.value = value;
this.featureImportance = featureImportance;
}
@ -29,7 +29,7 @@ public class RawInferenceResults implements InferenceResults {
return value;
}
public Map<String, Double> getFeatureImportance() {
public Map<String, double[]> getFeatureImportance() {
return featureImportance;
}

View File

@ -14,8 +14,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
public class RegressionInferenceResults extends SingleValueInferenceResults {
@ -24,14 +25,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
private final String resultsField;
public RegressionInferenceResults(double value, InferenceConfig config) {
this(value, (RegressionConfig) config, Collections.emptyMap());
this(value, (RegressionConfig) config, Collections.emptyList());
}
public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
this(value, (RegressionConfig)config, featureImportance);
}
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
super(value,
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
regressionConfig.getNumTopFeatureImportanceValues()));
@ -70,7 +71,10 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
document.setFieldValue(parentResultField + "." + this.resultsField, value());
if (getFeatureImportance().size() > 0) {
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
.stream()
.map(FeatureImportance::toMap)
.collect(Collectors.toList()));
}
}

View File

@ -8,45 +8,46 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.List;
import java.util.stream.Collectors;
public abstract class SingleValueInferenceResults implements InferenceResults {
private final double value;
private final Map<String, Double> featureImportance;
private final List<FeatureImportance> featureImportance;
static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
return unsortedFeatureImportances.entrySet()
.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
return unsortedFeatureImportances;
}
return unsortedFeatureImportances.stream()
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
.limit(numTopFeatures)
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
.collect(Collectors.toList());
}
SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
this.featureImportance = in.readList(FeatureImportance::new);
} else {
this.featureImportance = Collections.emptyMap();
this.featureImportance = Collections.emptyList();
}
}
SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
this.value = value;
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
}
public Double value() {
return value;
}
public Map<String, Double> getFeatureImportance() {
public List<FeatureImportance> getFeatureImportance() {
return featureImportance;
}
@ -58,7 +59,7 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
out.writeList(this.featureImportance);
}
}

View File

@ -8,12 +8,14 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@ -100,18 +102,46 @@ public final class InferenceHelpers {
return null;
}
public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, Double> featureImportances) {
public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
Map<String, double[]> featureImportances) {
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
return featureImportances;
}
Map<String, Double> originalFeatureImportance = new HashMap<>();
Map<String, double[]> originalFeatureImportance = new HashMap<>();
featureImportances.forEach((feature, importance) -> {
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : sumDoubleArrays(importance, v1));
});
return originalFeatureImportance;
}
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
@Nullable List<String> classificationLabels) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
featureImportance.forEach((k, v) -> {
// This indicates regression, or logistic regression
// If the length > 1, we assume multi-class classification.
if (v.length == 1) {
importances.add(FeatureImportance.forRegression(k, v[0]));
} else {
Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
// If the classificationLabels exist, their length must match leaf_value length
assert classificationLabels == null || classificationLabels.size() == v.length;
for (int i = 0; i < v.length; i++) {
classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
}
importances.add(FeatureImportance.forClassification(k, classImportance));
}
});
return importances;
}
public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
assert sumTo != null && inc != null && sumTo.length == inc.length;
for (int i = 0; i < inc.length; i++) {
sumTo[i] += inc[i];
}
return sumTo;
}
}

View File

@ -60,9 +60,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* NOTE: Must be thread safe
* @param fields The fields inferring against
* @param featureDecoder A Map translating processed feature names to their original feature names
* @return A {@code Map<String, Double>} mapping each featureName to its importance
* @return A {@code Map<String, double[]>} mapping each featureName to its importance
*/
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
default Version getMinimalCompatibilityVersion() {
return Version.V_7_6_0;

View File

@ -45,6 +45,8 @@ import java.util.OptionalDouble;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
@ -139,7 +141,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
double[][] inferenceResults = new double[this.models.size()][];
List<Map<String, Double>> featureInfluence = new ArrayList<>();
List<Map<String, double[]>> featureInfluence = new ArrayList<>();
int i = 0;
NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
for (TrainedModel model : models) {
@ -152,7 +154,9 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
}
}
double[] processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, featureInfluence, config, featureDecoderMap);
return buildResults(processed,
decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
config);
}
@Override
@ -161,19 +165,19 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
}
private InferenceResults buildResults(double[] processedInferences,
List<Map<String, Double>> featureInfluence,
InferenceConfig config,
Map<String, String> featureDecoderMap) {
Map<String, double[]> featureInfluence,
InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
return new RawInferenceResults(
new double[] {outputAggregator.aggregate(processedInferences)},
featureInfluence);
}
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
transformFeatureImportance(featureInfluence, null));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@ -186,7 +190,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return new ClassificationInferenceResults((double)topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
transformFeatureImportance(featureInfluence, classificationLabels),
config);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
@ -313,20 +317,23 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
}
Map<String, Double> featureImportance(Map<String, Object> fields) {
Map<String, double[]> featureImportance(Map<String, Object> fields) {
return featureImportance(fields, Collections.emptyMap());
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
Map<String, Double> collapsed = mergeFeatureImportances(models.stream()
public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
Map<String, double[]> collapsed = mergeFeatureImportances(models.stream()
.map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
.collect(Collectors.toList()));
return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed);
return decodeFeatureImportances(featureDecoder, collapsed);
}
private static Map<String, Double> mergeFeatureImportances(List<Map<String, Double>> featureImportances) {
return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll);
private static Map<String, double[]> mergeFeatureImportances(List<Map<String, double[]>> featureImportances) {
return featureImportances.stream()
.collect(HashMap::new,
(a, b) -> b.forEach((k, v) -> a.merge(k, v, InferenceHelpers::sumDoubleArrays)),
Map::putAll);
}
public static Builder builder() {

View File

@ -142,7 +142,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
return new ClassificationInferenceResults(topClasses.v1(),
LANGUAGE_NAMES.get(topClasses.v1()),
topClasses.v2(),
Collections.emptyMap(),
Collections.emptyList(),
classificationConfig);
}
@ -170,7 +170,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
}

View File

@ -91,8 +91,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
private final List<String> classificationLabels;
private final CachedSupplier<Double> highestOrderCategory;
// populated lazily when feature importance is calculated
private double[] nodeEstimates;
private Integer maxDepth;
private Integer leafSize;
Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@ -137,7 +137,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
.collect(Collectors.toList());
Map<String, Double> featureImportance = config.requestingImportance() ?
Map<String, double[]> featureImportance = config.requestingImportance() ?
featureImportance(features, featureDecoderMap) :
Collections.emptyMap();
@ -149,7 +149,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return buildResult(node.getLeafValue(), featureImportance, config);
}
private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
private InferenceResults buildResult(double[] value, Map<String, double[]> featureImportance, InferenceConfig config) {
assert value != null && value.length > 0;
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
@ -166,10 +166,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return new ClassificationInferenceResults(topClasses.v1(),
classificationLabel(topClasses.v1(), classificationLabels),
topClasses.v2(),
featureImportance,
InferenceHelpers.transformFeatureImportance(featureImportance, classificationLabels),
config);
case REGRESSION:
return new RegressionInferenceResults(value[0], config, featureImportance);
return new RegressionInferenceResults(value[0],
config,
InferenceHelpers.transformFeatureImportance(featureImportance, null));
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
@ -283,7 +285,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
@Override
public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
}
@ -293,9 +295,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return featureImportance(features, featureDecoder);
}
private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
calculateNodeEstimatesIfNeeded();
double[] featureImportance = new double[fieldValues.size()];
private Map<String, double[]> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
calculateDepthAndLeafValueSize();
double[][] featureImportance = new double[fieldValues.size()][leafSize];
for (int i = 0; i < fieldValues.size(); i++) {
featureImportance[i] = new double[leafSize];
}
int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
for (int i = 0; i < arrSize; i++) {
@ -303,24 +308,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
double[] scale = new double[arrSize];
ShapPath initialPath = new ShapPath(elements, scale);
shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
return InferenceHelpers.decodeFeatureImportances(featureDecoder,
IntStream.range(0, featureImportance.length)
.boxed()
.collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
}
private void calculateNodeEstimatesIfNeeded() {
if (this.nodeEstimates != null && this.maxDepth != null) {
private void calculateDepthAndLeafValueSize() {
if (this.maxDepth != null && this.leafSize != null) {
return;
}
synchronized (this) {
if (this.nodeEstimates != null && this.maxDepth != null) {
if (this.maxDepth != null && this.leafSize != null) {
return;
}
double[] estimates = new double[nodes.size()];
this.maxDepth = fillNodeEstimates(estimates, 0, 0);
this.nodeEstimates = estimates;
this.maxDepth = getDepth(0, 0);
}
}
@ -331,23 +334,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
* side first and then ported to the Java side.
*/
private void shapRecursive(List<Double> processedFeatures,
double[] nodeValues,
ShapPath parentSplitPath,
int nodeIndex,
double parentFractionZero,
double parentFractionOne,
int parentFeatureIndex,
double[] featureImportance,
double[][] featureImportance,
int nextIndex) {
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
TreeNode currNode = nodes.get(nodeIndex);
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
double leafValue = nodeValues[nodeIndex];
double[] leafValue = currNode.getLeafValue();
for (int i = 1; i < nextIndex; ++i) {
double scale = splitPath.sumUnwoundPath(i, nextIndex);
int inputColumnIndex = splitPath.featureIndex(i);
featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
for (int j = 0; j < leafValue.length; j++) {
featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
}
}
} else {
int hotIndex = currNode.compare(processedFeatures);
@ -365,41 +369,32 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
shapRecursive(processedFeatures, nodeValues, splitPath,
shapRecursive(processedFeatures, splitPath,
hotIndex, incomingFractionZero * hotFractionZero,
incomingFractionOne, splitFeature, featureImportance, nextIndex);
shapRecursive(processedFeatures, nodeValues, splitPath,
shapRecursive(processedFeatures, splitPath,
coldIndex, incomingFractionZero * coldFractionZero,
0.0, splitFeature, featureImportance, nextIndex);
}
}
/**
* This recursively populates the provided {@code double[]} with the node estimated values
* Get the depth of the tree and sets leafSize if it is null
*
* Used when calculating feature importance.
* @param nodeEstimates Array to update in place with the node estimated values
* @param nodeIndex Current node index
* @param depth Current depth
* @return The current max depth
*/
private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
private int getDepth(int nodeIndex, int depth) {
TreeNode node = nodes.get(nodeIndex);
if (node.isLeaf()) {
// TODO multi-value????
nodeEstimates[nodeIndex] = node.getLeafValue()[0];
if (leafSize == null) {
this.leafSize = node.getLeafValue().length;
}
return 0;
}
int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
long divisor = leftWeight + rightWeight;
double averageValue = divisor == 0 ?
0.0 :
(leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
nodeEstimates[nodeIndex] = averageValue;
int depthLeft = getDepth(node.getLeftChild(), depth + 1);
int depthRight = getDepth(node.getRightChild(), depth + 1);
return Math.max(depthLeft, depthRight) + 1;
}

View File

@ -16,20 +16,30 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {
public static ClassificationInferenceResults createRandomResults() {
Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
FeatureImportanceTests::randomClassification :
FeatureImportanceTests::randomRegression;
return new ClassificationInferenceResults(randomDouble(),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null :
Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList()),
randomBoolean() ? null :
Stream.generate(featureImportanceCtor)
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()),
ClassificationConfigTests.randomClassificationConfig());
}
@ -81,6 +91,40 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo"));
}
public void testWriteResultsWithImportance() {
Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
FeatureImportanceTests::randomClassification :
FeatureImportanceTests::randomRegression;
List<FeatureImportance> importanceList = Stream.generate(featureImportanceCtor)
.limit(5)
.collect(Collectors.toList());
ClassificationInferenceResults result = new ClassificationInferenceResults(0.0,
"foo",
Collections.emptyList(),
importanceList,
new ClassificationConfig(0, "predicted_value", "top_classes", 3));
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
@SuppressWarnings("unchecked")
List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
"result_field.feature_importance",
List.class);
assertThat(writtenImportance, hasSize(3));
importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
for (int i = 0; i < 3; i++) {
Map<String, Object> objectMap = writtenImportance.get(i);
FeatureImportance importance = importanceList.get(i);
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
if (importance.getClassImportance() != null) {
importance.getClassImportance().forEach((k, v) -> assertThat(objectMap.get(k), equalTo(v)));
}
}
}
@Override
protected ClassificationInferenceResults createTestInstance() {
return createRandomResults();

View File

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class FeatureImportanceTests extends AbstractWireSerializingTestCase<FeatureImportance> {
public static FeatureImportance createRandomInstance() {
return randomBoolean() ? randomClassification() : randomRegression();
}
static FeatureImportance randomRegression() {
return FeatureImportance.forRegression(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
}
static FeatureImportance randomClassification() {
return FeatureImportance.forClassification(
randomAlphaOfLength(10),
Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(2, 10))
.collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
}
@Override
protected FeatureImportance createTestInstance() {
return createRandomInstance();
}
@Override
protected Writeable.Reader<FeatureImportance> instanceReader() {
return FeatureImportance::new;
}
}

View File

@ -22,7 +22,8 @@ public class RawInferenceResultsTests extends ESTestCase {
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
return new RawInferenceResults(results, randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
return new RawInferenceResults(results,
randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", new double[]{1.08}));
}
public void testEqualityAndHashcode() {
@ -31,7 +32,9 @@ public class RawInferenceResultsTests extends ESTestCase {
for (int i = 0; i < n; i++) {
results[i] = randomDouble();
}
Map<String, Double> importance = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08);
Map<String, double[]> importance = randomBoolean() ?
Collections.emptyMap() :
Collections.singletonMap("foo", new double[]{1.08, 42.0});
RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
assertThat(lft, equalTo(rgt));

View File

@ -8,19 +8,28 @@ package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase<RegressionInferenceResults> {
public static RegressionInferenceResults createRandomResults() {
return new RegressionInferenceResults(randomDouble(), RegressionConfigTests.randomRegressionConfig());
return new RegressionInferenceResults(randomDouble(),
RegressionConfigTests.randomRegressionConfig(),
randomBoolean() ? null :
Stream.generate(FeatureImportanceTests::randomRegression)
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList()));
}
public void testWriteResults() {
@ -31,6 +40,32 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
}
public void testWriteResultsWithImportance() {
List<FeatureImportance> importanceList = Stream.generate(FeatureImportanceTests::randomRegression)
.limit(5)
.collect(Collectors.toList());
RegressionInferenceResults result = new RegressionInferenceResults(0.3,
new RegressionConfig("predicted_value", 3),
importanceList);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
@SuppressWarnings("unchecked")
List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
"result_field.feature_importance",
List.class);
assertThat(writtenImportance, hasSize(3));
importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
for (int i = 0; i < 3; i++) {
Map<String, Object> objectMap = writtenImportance.get(i);
FeatureImportance importance = importanceList.get(i);
assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
assertThat(objectMap.size(), equalTo(2));
}
}
@Override
protected RegressionInferenceResults createTestInstance() {
return createRandomResults();

View File

@ -684,45 +684,45 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.build();
Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
Map<String, double[]> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
assertThat(featureImportance.get("foo")[0], closeTo(0.0798679, eps));
assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));
assertThat(featureImportance.get("foo")[0], closeTo(1.80491886, eps));
assertThat(featureImportance.get("bar")[0], closeTo(-0.4355742, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
}

View File

@ -390,22 +390,22 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
Map<String, Double> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
Map<String, double[]> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));
featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));
}
public void testMaxFeatureIndex() {

View File

@ -156,8 +156,10 @@ public class InferenceIngestIT extends ESRestTestCase {
String responseString = EntityUtils.toString(response.getEntity());
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
assertThat(responseString, containsString("\"predicted_value\":1.0"));
assertThat(responseString, containsString("\"col2\":0.944"));
assertThat(responseString, containsString("\"col1\":0.19999"));
assertThat(responseString, containsString("\"feature_name\":\"col1\""));
assertThat(responseString, containsString("\"feature_name\":\"col2\""));
assertThat(responseString, containsString("\"importance\":0.944"));
assertThat(responseString, containsString("\"importance\":0.19999"));
String sourceWithMissingModel = "{\n" +
" \"pipeline\": {\n" +
@ -221,8 +223,10 @@ public class InferenceIngestIT extends ESRestTestCase {
Response response = client().performRequest(simulateRequest(source));
String responseString = EntityUtils.toString(response.getEntity());
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
assertThat(responseString, containsString("\"col2\":0.944"));
assertThat(responseString, containsString("\"col1\":0.19999"));
assertThat(responseString, containsString("\"feature_name\":\"col1\""));
assertThat(responseString, containsString("\"feature_name\":\"col2\""));
assertThat(responseString, containsString("\"importance\":0.944"));
assertThat(responseString, containsString("\"importance\":0.19999"));
}
public void testSimulateLangIdent() throws IOException {

View File

@ -10,6 +10,7 @@ import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@ -120,9 +121,9 @@ public class InferenceProcessorTests extends ESTestCase {
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);
List<FeatureImportance> featureInfluence = new ArrayList<>();
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0,
@ -135,8 +136,10 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
}
@SuppressWarnings("unchecked")
@ -205,9 +208,9 @@ public class InferenceProcessorTests extends ESTestCase {
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
Map<String, Double> featureInfluence = new HashMap<>();
featureInfluence.put("feature_1", 1.13);
featureInfluence.put("feature_2", -42.0);
List<FeatureImportance> featureInfluence = new ArrayList<>();
featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
InternalInferModelAction.Response response = new InternalInferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
@ -215,8 +218,10 @@ public class InferenceProcessorTests extends ESTestCase {
assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
}
public void testGenerateRequestWithEmptyMapping() {