Adds multi-class feature importance calculation.

Feature importance objects are now mapped as follows:

(logistic) Regression:
```
{
    "feature_name": "feature_0",
    "importance": -1.3
}
```

Multi-class [class names are `foo`, `bar`, `baz`]:
```
{
    "feature_name": "feature_0",
    "importance": 2.0, // sum(abs()) of class importances
    "foo": 1.0,
    "bar": 0.5,
    "baz": -0.5
}
```

For users to get the full benefit of aggregating and searching for feature importance, they should update their index mapping as follows (before turning this option on in their pipelines):
```
"ml.inference.feature_importance": {
    "type": "nested",
    "dynamic": true,
    "properties": {
        "feature_name": {
            "type": "keyword"
        },
        "importance": {
            "type": "double"
        }
    }
}
```

The mapping field name is as follows: `ml.<inference.target_field>.<inference.tag>.feature_importance`. If `inference.tag` is not provided in the processor definition, it is not part of the field path. `inference.target_field` defaults to `ml.inference`.

//cc @lcawl ^ Where should we document this?

If this makes it in for 7.7, there shouldn't be any feature_importance at inference BWC worries, as 7.7 is the first version to have it.
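For illustration only, here is a minimal, hypothetical sketch (plain Java; the class name is invented and this code is not part of the change) showing how the top-level `importance` in the multi-class example above follows the sum(abs()) rule over the per-class importances:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class MultiClassImportanceExample {
    public static void main(String[] args) {
        // Per-class importances for feature_0, using the foo/bar/baz values from the example above.
        Map<String, Double> classImportance = new LinkedHashMap<>();
        classImportance.put("foo", 1.0);
        classImportance.put("bar", 0.5);
        classImportance.put("baz", -0.5);

        // The overall importance is the sum of the absolute class importances: 1.0 + 0.5 + 0.5 = 2.0
        double importance = classImportance.values().stream().mapToDouble(Math::abs).sum();
        System.out.println(importance); // prints 2.0
    }
}
```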
This commit is contained in:
parent 5c96a7e210
commit 19af869243
@@ -35,13 +35,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
                                           String classificationLabel,
                                           List<TopClassEntry> topClasses,
                                           InferenceConfig config) {
-        this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
+        this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
     }

     public ClassificationInferenceResults(double value,
                                           String classificationLabel,
                                           List<TopClassEntry> topClasses,
-                                          Map<String, Double> featureImportance,
+                                          List<FeatureImportance> featureImportance,
                                           InferenceConfig config) {
         this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
     }
@@ -49,7 +49,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
     private ClassificationInferenceResults(double value,
                                            String classificationLabel,
                                            List<TopClassEntry> topClasses,
-                                           Map<String, Double> featureImportance,
+                                           List<FeatureImportance> featureImportance,
                                            ClassificationConfig classificationConfig) {
         super(value,
             SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
@@ -118,7 +118,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
                 topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
         }
         if (getFeatureImportance().size() > 0) {
-            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
+            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
+                .stream()
+                .map(FeatureImportance::toMap)
+                .collect(Collectors.toList()));
         }
     }
@@ -0,0 +1,97 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class FeatureImportance implements Writeable {

    private final Map<String, Double> classImportance;
    private final double importance;
    private final String featureName;
    private static final String IMPORTANCE = "importance";
    private static final String FEATURE_NAME = "feature_name";

    public static FeatureImportance forRegression(String featureName, double importance) {
        return new FeatureImportance(featureName, importance, null);
    }

    public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
        return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
    }

    private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
        this.featureName = Objects.requireNonNull(featureName);
        this.importance = importance;
        this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
    }

    public FeatureImportance(StreamInput in) throws IOException {
        this.featureName = in.readString();
        this.importance = in.readDouble();
        if (in.readBoolean()) {
            this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
        } else {
            this.classImportance = null;
        }
    }

    public Map<String, Double> getClassImportance() {
        return classImportance;
    }

    public double getImportance() {
        return importance;
    }

    public String getFeatureName() {
        return featureName;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.featureName);
        out.writeDouble(this.importance);
        out.writeBoolean(this.classImportance != null);
        if (this.classImportance != null) {
            out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
        }
    }

    public Map<String, Object> toMap() {
        Map<String, Object> map = new LinkedHashMap<>();
        map.put(FEATURE_NAME, featureName);
        map.put(IMPORTANCE, importance);
        if (classImportance != null) {
            classImportance.forEach(map::put);
        }
        return map;
    }

    @Override
    public boolean equals(Object object) {
        if (object == this) { return true; }
        if (object == null || getClass() != object.getClass()) { return false; }
        FeatureImportance that = (FeatureImportance) object;
        return Objects.equals(featureName, that.featureName)
            && Objects.equals(importance, that.importance)
            && Objects.equals(classImportance, that.classImportance);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureName, importance, classImportance);
    }

}
@@ -18,9 +18,9 @@ public class RawInferenceResults implements InferenceResults {
     public static final String NAME = "raw";

     private final double[] value;
-    private final Map<String, Double> featureImportance;
+    private final Map<String, double[]> featureImportance;

-    public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
+    public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
         this.value = value;
         this.featureImportance = featureImportance;
     }
@@ -29,7 +29,7 @@ public class RawInferenceResults implements InferenceResults {
         return value;
     }

-    public Map<String, Double> getFeatureImportance() {
+    public Map<String, double[]> getFeatureImportance() {
         return featureImportance;
     }
@@ -14,8 +14,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
 import java.util.Collections;
-import java.util.Map;
+import java.util.List;
 import java.util.Objects;
+import java.util.stream.Collectors;

 public class RegressionInferenceResults extends SingleValueInferenceResults {

@@ -24,14 +25,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     private final String resultsField;

     public RegressionInferenceResults(double value, InferenceConfig config) {
-        this(value, (RegressionConfig) config, Collections.emptyMap());
+        this(value, (RegressionConfig) config, Collections.emptyList());
     }

-    public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
+    public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
         this(value, (RegressionConfig)config, featureImportance);
     }

-    private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
+    private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
         super(value,
             SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
                 regressionConfig.getNumTopFeatureImportanceValues()));
@@ -70,7 +71,10 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
         ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
         document.setFieldValue(parentResultField + "." + this.resultsField, value());
         if (getFeatureImportance().size() > 0) {
-            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
+            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
+                .stream()
+                .map(FeatureImportance::toMap)
+                .collect(Collectors.toList()));
         }
     }
@@ -8,45 +8,46 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.io.IOException;
 import java.util.Collections;
-import java.util.LinkedHashMap;
-import java.util.Map;
+import java.util.List;
+import java.util.stream.Collectors;

 public abstract class SingleValueInferenceResults implements InferenceResults {

     private final double value;
-    private final Map<String, Double> featureImportance;
+    private final List<FeatureImportance> featureImportance;

-    static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
-        return unsortedFeatureImportances.entrySet()
-            .stream()
-            .sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
+    static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
+        if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
+            return unsortedFeatureImportances;
+        }
+        return unsortedFeatureImportances.stream()
+            .sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
             .limit(numTopFeatures)
-            .collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
+            .collect(Collectors.toList());
     }

     SingleValueInferenceResults(StreamInput in) throws IOException {
         value = in.readDouble();
         if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
-            this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
+            this.featureImportance = in.readList(FeatureImportance::new);
         } else {
-            this.featureImportance = Collections.emptyMap();
+            this.featureImportance = Collections.emptyList();
         }
     }

-    SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
+    SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
         this.value = value;
-        this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
+        this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
     }

     public Double value() {
         return value;
     }

-    public Map<String, Double> getFeatureImportance() {
+    public List<FeatureImportance> getFeatureImportance() {
         return featureImportance;
     }

@@ -58,7 +59,7 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
     public void writeTo(StreamOutput out) throws IOException {
         out.writeDouble(value);
         if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
-            out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
+            out.writeList(this.featureImportance);
         }
     }
@@ -8,12 +8,14 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -100,18 +102,46 @@ public final class InferenceHelpers {
         return null;
     }

-    public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
-                                                               Map<String, Double> featureImportances) {
+    public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
+                                                                 Map<String, double[]> featureImportances) {
         if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
             return featureImportances;
         }

-        Map<String, Double> originalFeatureImportance = new HashMap<>();
+        Map<String, double[]> originalFeatureImportance = new HashMap<>();
         featureImportances.forEach((feature, importance) -> {
             String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
-            originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
+            originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : sumDoubleArrays(importance, v1));
         });

         return originalFeatureImportance;
     }

+    public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
+                                                                     @Nullable List<String> classificationLabels) {
+        List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
+        featureImportance.forEach((k, v) -> {
+            // This indicates regression, or logistic regression
+            // If the length > 1, we assume multi-class classification.
+            if (v.length == 1) {
+                importances.add(FeatureImportance.forRegression(k, v[0]));
+            } else {
+                Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
+                // If the classificationLabels exist, their length must match leaf_value length
+                assert classificationLabels == null || classificationLabels.size() == v.length;
+                for (int i = 0; i < v.length; i++) {
+                    classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
+                }
+                importances.add(FeatureImportance.forClassification(k, classImportance));
+            }
+        });
+        return importances;
+    }
+
+    public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
+        assert sumTo != null && inc != null && sumTo.length == inc.length;
+        for (int i = 0; i < inc.length; i++) {
+            sumTo[i] += inc[i];
+        }
+        return sumTo;
+    }
 }
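As a hedged illustration (a hypothetical driver class, not part of this diff; it assumes the FeatureImportance and InferenceHelpers classes above are on the classpath), this is roughly how the new transformFeatureImportance helper turns a per-feature importance vector into the document structure described in the commit message:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

// Assumed imports from this change:
// import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
// import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;

public class TransformFeatureImportanceSketch {
    public static void main(String[] args) {
        // One entry per original feature; the array holds one importance value per class.
        Map<String, double[]> raw = Collections.singletonMap("feature_0", new double[]{1.0, 0.5, -0.5});
        List<String> labels = Arrays.asList("foo", "bar", "baz");

        // Length > 1 takes the multi-class branch: values are keyed by label and the
        // overall importance becomes the sum of their absolute values (2.0 here).
        List<FeatureImportance> importances = InferenceHelpers.transformFeatureImportance(raw, labels);
        System.out.println(importances.get(0).toMap());
        // => {feature_name=feature_0, importance=2.0, foo=1.0, bar=0.5, baz=-0.5}
    }
}
```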
@@ -60,9 +60,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
      * NOTE: Must be thread safe
      * @param fields The fields inferring against
      * @param featureDecoder A Map translating processed feature names to their original feature names
-     * @return A {@code Map<String, Double>} mapping each featureName to its importance
+     * @return A {@code Map<String, double[]>} mapping each featureName to its importance
      */
-    Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
+    Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);

     default Version getMinimalCompatibilityVersion() {
         return Version.V_7_6_0;
@@ -45,6 +45,8 @@ import java.util.OptionalDouble;
 import java.util.stream.Collectors;

 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;

 public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {

@@ -139,7 +141,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
                 "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
         }
         double[][] inferenceResults = new double[this.models.size()][];
-        List<Map<String, Double>> featureInfluence = new ArrayList<>();
+        List<Map<String, double[]>> featureInfluence = new ArrayList<>();
         int i = 0;
         NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
         for (TrainedModel model : models) {
@@ -152,7 +154,9 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
             }
         }
         double[] processed = outputAggregator.processValues(inferenceResults);
-        return buildResults(processed, featureInfluence, config, featureDecoderMap);
+        return buildResults(processed,
+            decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
+            config);
     }

     @Override
@@ -161,19 +165,19 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
     }

     private InferenceResults buildResults(double[] processedInferences,
-                                          List<Map<String, Double>> featureInfluence,
-                                          InferenceConfig config,
-                                          Map<String, String> featureDecoderMap) {
+                                          Map<String, double[]> featureInfluence,
+                                          InferenceConfig config) {
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
-            return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
-                InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
+            return new RawInferenceResults(
+                new double[] {outputAggregator.aggregate(processedInferences)},
+                featureInfluence);
         }
         switch(targetType) {
             case REGRESSION:
                 return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
                     config,
-                    InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
+                    transformFeatureImportance(featureInfluence, null));
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
                 assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@@ -186,7 +190,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
                 return new ClassificationInferenceResults((double)topClasses.v1(),
                     classificationLabel(topClasses.v1(), classificationLabels),
                     topClasses.v2(),
-                    InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
+                    transformFeatureImportance(featureInfluence, classificationLabels),
                     config);
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
@@ -313,20 +317,23 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
         return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
     }

-    Map<String, Double> featureImportance(Map<String, Object> fields) {
+    Map<String, double[]> featureImportance(Map<String, Object> fields) {
         return featureImportance(fields, Collections.emptyMap());
     }

     @Override
-    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
-        Map<String, Double> collapsed = mergeFeatureImportances(models.stream()
+    public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+        Map<String, double[]> collapsed = mergeFeatureImportances(models.stream()
             .map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
             .collect(Collectors.toList()));
-        return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed);
+        return decodeFeatureImportances(featureDecoder, collapsed);
     }

-    private static Map<String, Double> mergeFeatureImportances(List<Map<String, Double>> featureImportances) {
-        return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll);
+    private static Map<String, double[]> mergeFeatureImportances(List<Map<String, double[]>> featureImportances) {
+        return featureImportances.stream()
+            .collect(HashMap::new,
+                (a, b) -> b.forEach((k, v) -> a.merge(k, v, InferenceHelpers::sumDoubleArrays)),
+                Map::putAll);
     }

     public static Builder builder() {
@@ -142,7 +142,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
         return new ClassificationInferenceResults(topClasses.v1(),
             LANGUAGE_NAMES.get(topClasses.v1()),
             topClasses.v2(),
-            Collections.emptyMap(),
+            Collections.emptyList(),
             classificationConfig);
     }

@@ -170,7 +170,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
     }

     @Override
-    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+    public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
         throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
     }
@@ -91,8 +91,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     private final List<String> classificationLabels;
     private final CachedSupplier<Double> highestOrderCategory;
     // populated lazily when feature importance is calculated
-    private double[] nodeEstimates;
     private Integer maxDepth;
+    private Integer leafSize;

     Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
         this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@@ -137,7 +137,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
             .collect(Collectors.toList());

-        Map<String, Double> featureImportance = config.requestingImportance() ?
+        Map<String, double[]> featureImportance = config.requestingImportance() ?
             featureImportance(features, featureDecoderMap) :
             Collections.emptyMap();

@@ -149,7 +149,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return buildResult(node.getLeafValue(), featureImportance, config);
     }

-    private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
+    private InferenceResults buildResult(double[] value, Map<String, double[]> featureImportance, InferenceConfig config) {
         assert value != null && value.length > 0;
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
@@ -166,10 +166,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
                 return new ClassificationInferenceResults(topClasses.v1(),
                     classificationLabel(topClasses.v1(), classificationLabels),
                     topClasses.v2(),
-                    featureImportance,
+                    InferenceHelpers.transformFeatureImportance(featureImportance, classificationLabels),
                     config);
             case REGRESSION:
-                return new RegressionInferenceResults(value[0], config, featureImportance);
+                return new RegressionInferenceResults(value[0],
+                    config,
+                    InferenceHelpers.transformFeatureImportance(featureImportance, null));
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
         }
@@ -283,7 +285,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     }

     @Override
-    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+    public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
         if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
             throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
         }
@@ -293,9 +295,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return featureImportance(features, featureDecoder);
     }

-    private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
-        calculateNodeEstimatesIfNeeded();
-        double[] featureImportance = new double[fieldValues.size()];
+    private Map<String, double[]> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
+        calculateDepthAndLeafValueSize();
+        double[][] featureImportance = new double[fieldValues.size()][leafSize];
+        for (int i = 0; i < fieldValues.size(); i++) {
+            featureImportance[i] = new double[leafSize];
+        }
         int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
         ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
         for (int i = 0; i < arrSize; i++) {
@@ -303,24 +308,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         }
         double[] scale = new double[arrSize];
         ShapPath initialPath = new ShapPath(elements, scale);
-        shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
+        shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
         return InferenceHelpers.decodeFeatureImportances(featureDecoder,
             IntStream.range(0, featureImportance.length)
                 .boxed()
                 .collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
     }

-    private void calculateNodeEstimatesIfNeeded() {
-        if (this.nodeEstimates != null && this.maxDepth != null) {
+    private void calculateDepthAndLeafValueSize() {
+        if (this.maxDepth != null && this.leafSize != null) {
             return;
         }
         synchronized (this) {
-            if (this.nodeEstimates != null && this.maxDepth != null) {
+            if (this.maxDepth != null && this.leafSize != null) {
                 return;
             }
-            double[] estimates = new double[nodes.size()];
-            this.maxDepth = fillNodeEstimates(estimates, 0, 0);
-            this.nodeEstimates = estimates;
+            this.maxDepth = getDepth(0, 0);
         }
     }

@@ -331,23 +334,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
      * side first and then ported to the Java side.
      */
     private void shapRecursive(List<Double> processedFeatures,
-                               double[] nodeValues,
                                ShapPath parentSplitPath,
                                int nodeIndex,
                                double parentFractionZero,
                                double parentFractionOne,
                                int parentFeatureIndex,
-                               double[] featureImportance,
+                               double[][] featureImportance,
                                int nextIndex) {
         ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
         TreeNode currNode = nodes.get(nodeIndex);
         nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
         if (currNode.isLeaf()) {
-            double leafValue = nodeValues[nodeIndex];
+            double[] leafValue = currNode.getLeafValue();
             for (int i = 1; i < nextIndex; ++i) {
-                double scale = splitPath.sumUnwoundPath(i, nextIndex);
                 int inputColumnIndex = splitPath.featureIndex(i);
-                featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
+                double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
+                for (int j = 0; j < leafValue.length; j++) {
+                    featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
+                }
             }
         } else {
             int hotIndex = currNode.compare(processedFeatures);
@@ -365,41 +369,32 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM

             double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
             double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
-            shapRecursive(processedFeatures, nodeValues, splitPath,
+            shapRecursive(processedFeatures, splitPath,
                 hotIndex, incomingFractionZero * hotFractionZero,
                 incomingFractionOne, splitFeature, featureImportance, nextIndex);
-            shapRecursive(processedFeatures, nodeValues, splitPath,
+            shapRecursive(processedFeatures, splitPath,
                 coldIndex, incomingFractionZero * coldFractionZero,
                 0.0, splitFeature, featureImportance, nextIndex);
         }
     }

     /**
-     * This recursively populates the provided {@code double[]} with the node estimated values
+     * Get the depth of the tree and sets leafSize if it is null
      *
      * Used when calculating feature importance.
-     * @param nodeEstimates Array to update in place with the node estimated values
      * @param nodeIndex Current node index
      * @param depth Current depth
      * @return The current max depth
      */
-    private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
+    private int getDepth(int nodeIndex, int depth) {
         TreeNode node = nodes.get(nodeIndex);
         if (node.isLeaf()) {
-            // TODO multi-value????
-            nodeEstimates[nodeIndex] = node.getLeafValue()[0];
+            if (leafSize == null) {
+                this.leafSize = node.getLeafValue().length;
+            }
             return 0;
         }

-        int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
-        int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
-        long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
-        long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
-        long divisor = leftWeight + rightWeight;
-        double averageValue = divisor == 0 ?
-            0.0 :
-            (leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
-        nodeEstimates[nodeIndex] = averageValue;
+        int depthLeft = getDepth(node.getLeftChild(), depth + 1);
+        int depthRight = getDepth(node.getRightChild(), depth + 1);
         return Math.max(depthLeft, depthRight) + 1;
     }
@@ -16,20 +16,30 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;

 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;

 public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {

     public static ClassificationInferenceResults createRandomResults() {
+        Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
+            FeatureImportanceTests::randomClassification :
+            FeatureImportanceTests::randomRegression;
+
         return new ClassificationInferenceResults(randomDouble(),
             randomBoolean() ? null : randomAlphaOfLength(10),
             randomBoolean() ? null :
                 Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
                     .limit(randomIntBetween(0, 10))
                     .collect(Collectors.toList()),
+            randomBoolean() ? null :
+                Stream.generate(featureImportanceCtor)
+                    .limit(randomIntBetween(1, 10))
+                    .collect(Collectors.toList()),
             ClassificationConfigTests.randomClassificationConfig());
     }

@@ -81,6 +91,40 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
         assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo"));
     }

+    public void testWriteResultsWithImportance() {
+        Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
+            FeatureImportanceTests::randomClassification :
+            FeatureImportanceTests::randomRegression;
+
+        List<FeatureImportance> importanceList = Stream.generate(featureImportanceCtor)
+            .limit(5)
+            .collect(Collectors.toList());
+        ClassificationInferenceResults result = new ClassificationInferenceResults(0.0,
+            "foo",
+            Collections.emptyList(),
+            importanceList,
+            new ClassificationConfig(0, "predicted_value", "top_classes", 3));
+        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
+        result.writeResult(document, "result_field");
+
+        assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
+        @SuppressWarnings("unchecked")
+        List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
+            "result_field.feature_importance",
+            List.class);
+        assertThat(writtenImportance, hasSize(3));
+        importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
+        for (int i = 0; i < 3; i++) {
+            Map<String, Object> objectMap = writtenImportance.get(i);
+            FeatureImportance importance = importanceList.get(i);
+            assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
+            assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
+            if (importance.getClassImportance() != null) {
+                importance.getClassImportance().forEach((k, v) -> assertThat(objectMap.get(k), equalTo(v)));
+            }
+        }
+    }
+
     @Override
     protected ClassificationInferenceResults createTestInstance() {
         return createRandomResults();
@@ -0,0 +1,44 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;


public class FeatureImportanceTests extends AbstractWireSerializingTestCase<FeatureImportance> {

    public static FeatureImportance createRandomInstance() {
        return randomBoolean() ? randomClassification() : randomRegression();
    }

    static FeatureImportance randomRegression() {
        return FeatureImportance.forRegression(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
    }

    static FeatureImportance randomClassification() {
        return FeatureImportance.forClassification(
            randomAlphaOfLength(10),
            Stream.generate(() -> randomAlphaOfLength(10))
                .limit(randomLongBetween(2, 10))
                .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
    }

    @Override
    protected FeatureImportance createTestInstance() {
        return createRandomInstance();
    }

    @Override
    protected Writeable.Reader<FeatureImportance> instanceReader() {
        return FeatureImportance::new;
    }
}
@@ -22,7 +22,8 @@ public class RawInferenceResultsTests extends ESTestCase {
         for (int i = 0; i < n; i++) {
             results[i] = randomDouble();
         }
-        return new RawInferenceResults(results, randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
+        return new RawInferenceResults(results,
+            randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", new double[]{1.08}));
     }

     public void testEqualityAndHashcode() {
@@ -31,7 +32,9 @@ public class RawInferenceResultsTests extends ESTestCase {
         for (int i = 0; i < n; i++) {
             results[i] = randomDouble();
         }
-        Map<String, Double> importance = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08);
+        Map<String, double[]> importance = randomBoolean() ?
+            Collections.emptyMap() :
+            Collections.singletonMap("foo", new double[]{1.08, 42.0});
         RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
         RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
         assertThat(lft, equalTo(rgt));
@@ -8,19 +8,28 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;

 import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;

 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;


 public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase<RegressionInferenceResults> {

     public static RegressionInferenceResults createRandomResults() {
-        return new RegressionInferenceResults(randomDouble(), RegressionConfigTests.randomRegressionConfig());
+        return new RegressionInferenceResults(randomDouble(),
+            RegressionConfigTests.randomRegressionConfig(),
+            randomBoolean() ? null :
+                Stream.generate(FeatureImportanceTests::randomRegression)
+                    .limit(randomIntBetween(1, 10))
+                    .collect(Collectors.toList()));
     }

     public void testWriteResults() {
@@ -31,6 +40,32 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
         assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
     }

+    public void testWriteResultsWithImportance() {
+        List<FeatureImportance> importanceList = Stream.generate(FeatureImportanceTests::randomRegression)
+            .limit(5)
+            .collect(Collectors.toList());
+        RegressionInferenceResults result = new RegressionInferenceResults(0.3,
+            new RegressionConfig("predicted_value", 3),
+            importanceList);
+        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
+        result.writeResult(document, "result_field");
+
+        assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
+        @SuppressWarnings("unchecked")
+        List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
+            "result_field.feature_importance",
+            List.class);
+        assertThat(writtenImportance, hasSize(3));
+        importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
+        for (int i = 0; i < 3; i++) {
+            Map<String, Object> objectMap = writtenImportance.get(i);
+            FeatureImportance importance = importanceList.get(i);
+            assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
+            assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
+            assertThat(objectMap.size(), equalTo(2));
+        }
+    }
+
     @Override
     protected RegressionInferenceResults createTestInstance() {
         return createRandomResults();
@@ -684,45 +684,45 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             .build();


-        Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        Map<String, double[]> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
-        assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(0.0798679, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
-        assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
-        assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(1.80491886, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(-0.4355742, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
-        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
-        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
-        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
-        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));

         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
-        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
-        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
     }
@@ -390,22 +390,22 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
             TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();

-        Map<String, Double> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
+        Map<String, double[]> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
             Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));

         featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));

         featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));

         featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));
     }

     public void testMaxFeatureIndex() {
@@ -156,8 +156,10 @@ public class InferenceIngestIT extends ESRestTestCase {
         String responseString = EntityUtils.toString(response.getEntity());
         assertThat(responseString, containsString("\"predicted_value\":\"second\""));
         assertThat(responseString, containsString("\"predicted_value\":1.0"));
-        assertThat(responseString, containsString("\"col2\":0.944"));
-        assertThat(responseString, containsString("\"col1\":0.19999"));
+        assertThat(responseString, containsString("\"feature_name\":\"col1\""));
+        assertThat(responseString, containsString("\"feature_name\":\"col2\""));
+        assertThat(responseString, containsString("\"importance\":0.944"));
+        assertThat(responseString, containsString("\"importance\":0.19999"));

         String sourceWithMissingModel = "{\n" +
             " \"pipeline\": {\n" +
@@ -221,8 +223,10 @@ public class InferenceIngestIT extends ESRestTestCase {
         Response response = client().performRequest(simulateRequest(source));
         String responseString = EntityUtils.toString(response.getEntity());
         assertThat(responseString, containsString("\"predicted_value\":\"second\""));
-        assertThat(responseString, containsString("\"col2\":0.944"));
-        assertThat(responseString, containsString("\"col1\":0.19999"));
+        assertThat(responseString, containsString("\"feature_name\":\"col1\""));
+        assertThat(responseString, containsString("\"feature_name\":\"col2\""));
+        assertThat(responseString, containsString("\"importance\":0.944"));
+        assertThat(responseString, containsString("\"importance\":0.19999"));
     }

     public void testSimulateLangIdent() throws IOException {
@@ -10,6 +10,7 @@ import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@@ -120,9 +121,9 @@ public class InferenceProcessorTests extends ESTestCase {
         classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
         classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));

-        Map<String, Double> featureInfluence = new HashMap<>();
-        featureInfluence.put("feature_1", 1.13);
-        featureInfluence.put("feature_2", -42.0);
+        List<FeatureImportance> featureInfluence = new ArrayList<>();
+        featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
+        featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));

         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new ClassificationInferenceResults(1.0,
@@ -135,8 +136,10 @@ public class InferenceProcessorTests extends ESTestCase {

         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
         assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
     }

     @SuppressWarnings("unchecked")
@@ -205,9 +208,9 @@ public class InferenceProcessorTests extends ESTestCase {
         Map<String, Object> ingestMetadata = new HashMap<>();
         IngestDocument document = new IngestDocument(source, ingestMetadata);

-        Map<String, Double> featureInfluence = new HashMap<>();
-        featureInfluence.put("feature_1", 1.13);
-        featureInfluence.put("feature_2", -42.0);
+        List<FeatureImportance> featureInfluence = new ArrayList<>();
+        featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
+        featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));

         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
@@ -215,8 +218,10 @@ public class InferenceProcessorTests extends ESTestCase {

         assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
     }

     public void testGenerateRequestWithEmptyMapping() {