From 8b33d8813a520ec3553cf7f6cac2c0e0f174499c Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 31 Aug 2020 13:57:00 -0400 Subject: [PATCH] [ML] binary classification per-class feature importance for model inference (#61597) (#61746) This commit addresses two issues: - per class feature importance is now written out for binary classification (logistic regression) - The `class_name` in per class feature importance now matches what is written in the `top_classes` array. backport of https://github.com/elastic/elasticsearch/pull/61597 --- .../inference/results/FeatureImportance.java | 18 +++++--- .../trainedmodel/InferenceHelpers.java | 36 +++++++++++++--- .../inference/EnsembleInferenceModel.java | 41 ++++++++++--------- .../inference/TreeInferenceModel.java | 7 +++- .../metadata/TotalFeatureImportance.java | 23 ++++++++--- .../inference/InferenceDefinitionTests.java | 25 +++++++++-- 6 files changed, 107 insertions(+), 43 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java index 3c1a395a1f7..0846acf3331 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java @@ -39,6 +39,12 @@ public class FeatureImportance implements Writeable, ToXContentObject { return new FeatureImportance(featureName, importance, null); } + public static FeatureImportance forBinaryClassification(String featureName, double importance, List classImportance) { + return new FeatureImportance(featureName, + importance, + classImportance); + } + public static FeatureImportance forClassification(String featureName, List classImportance) { return new FeatureImportance(featureName, classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(), @@ -170,27 +176,27 @@ public class FeatureImportance implements Writeable, ToXContentObject { } private static Map toMap(List importances) { - return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance)); + return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance)); } public static ClassImportance fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private final String className; + private final Object className; private final double importance; - public ClassImportance(String className, double importance) { + public ClassImportance(Object className, double importance) { this.className = className; this.importance = importance; } public ClassImportance(StreamInput in) throws IOException { - this.className = in.readString(); + this.className = in.readGenericValue(); this.importance = in.readDouble(); } - public String getClassName() { + public Object getClassName() { return className; } @@ -207,7 +213,7 @@ public class FeatureImportance implements Writeable, ToXContentObject { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(className); + out.writeGenericValue(className); out.writeDouble(importance); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index d4cadf33bf4..0b5bf658cb1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -129,21 +130,46 @@ public final class InferenceHelpers { return originalFeatureImportance; } - public static List transformFeatureImportance(Map featureImportance, - @Nullable List classificationLabels) { + public static List transformFeatureImportanceRegression(Map featureImportance) { List importances = new ArrayList<>(featureImportance.size()); + featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0]))); + return importances; + } + + public static List transformFeatureImportanceClassification(Map featureImportance, + final int predictedValue, + @Nullable List classificationLabels, + @Nullable PredictionFieldType predictionFieldType) { + List importances = new ArrayList<>(featureImportance.size()); + final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType; featureImportance.forEach((k, v) -> { - // This indicates regression, or logistic regression + // This indicates logistic regression (binary classification) // If the length > 1, we assume multi-class classification. if (v.length == 1) { - importances.add(FeatureImportance.forRegression(k, v[0])); + assert predictedValue == 1 || predictedValue == 0; + // If predicted value is `1`, then the other class is `0` + // If predicted value is `0`, then the other class is `1` + final int otherClass = 1 - predictedValue; + String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue); + String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass); + importances.add(FeatureImportance.forBinaryClassification(k, + v[0], + Arrays.asList( + new FeatureImportance.ClassImportance( + fieldType.transformPredictedValue((double)predictedValue, predictedLabel), + v[0]), + new FeatureImportance.ClassImportance( + fieldType.transformPredictedValue((double)otherClass, otherLabel), + -v[0]) + ))); } else { List classImportance = new ArrayList<>(v.length); // If the classificationLabels exist, their length must match leaf_value length assert classificationLabels == null || classificationLabels.size() == v.length; for (int i = 0; i < v.length; i++) { + String label = classificationLabels == null ? null : classificationLabels.get(i); classImportance.add(new FeatureImportance.ClassImportance( - classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), + fieldType.transformPredictedValue((double)i, label), v[i])); } importances.add(FeatureImportance.forClassification(k, classImportance)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index 2e65215f441..e43fd4a9b56 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -43,7 +43,8 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS; @@ -154,14 +155,7 @@ public class EnsembleInferenceModel implements InferenceModel { RawInferenceResults inferenceResult = (RawInferenceResults) result; inferenceResults[i++] = inferenceResult.getValue(); if (config.requestingImportance()) { - double[][] modelFeatureImportance = inferenceResult.getFeatureImportance(); - assert modelFeatureImportance.length == featureInfluence.length; - for (int j = 0; j < modelFeatureImportance.length; j++) { - if (featureInfluence[j] == null) { - featureInfluence[j] = new double[modelFeatureImportance[j].length]; - } - featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]); - } + addFeatureImportance(featureInfluence, inferenceResult); } } double[] processed = outputAggregator.processValues(inferenceResults); @@ -176,18 +170,22 @@ public class EnsembleInferenceModel implements InferenceModel { InferenceResults result = model.infer(features, subModelInferenceConfig); assert result instanceof RawInferenceResults; RawInferenceResults inferenceResult = (RawInferenceResults) result; - double[][] modelFeatureImportance = inferenceResult.getFeatureImportance(); - assert modelFeatureImportance.length == featureInfluence.length; - for (int j = 0; j < modelFeatureImportance.length; j++) { - if (featureInfluence[j] == null) { - featureInfluence[j] = new double[modelFeatureImportance[j].length]; - } - featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]); - } + addFeatureImportance(featureInfluence, inferenceResult); } return featureInfluence; } + private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) { + double[][] modelFeatureImportance = inferenceResult.getFeatureImportance(); + assert modelFeatureImportance.length == featureInfluence.length; + for (int j = 0; j < modelFeatureImportance.length; j++) { + if (featureInfluence[j] == null) { + featureInfluence[j] = new double[modelFeatureImportance[j].length]; + } + featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]); + } + } + private InferenceResults buildResults(double[] processedInferences, double[][] featureImportance, Map featureDecoderMap, @@ -208,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel { case REGRESSION: return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config, - transformFeatureImportance(decodedFeatureImportance, null)); + transformFeatureImportanceRegression(decodedFeatureImportance)); case CLASSIFICATION: ClassificationConfig classificationConfig = (ClassificationConfig) config; assert classificationWeights == null || processedInferences.length == classificationWeights.length; @@ -220,10 +218,13 @@ public class EnsembleInferenceModel implements InferenceModel { classificationConfig.getNumTopClasses(), classificationConfig.getPredictionFieldType()); final InferenceHelpers.TopClassificationValue value = topClasses.v1(); - return new ClassificationInferenceResults((double)value.getValue(), + return new ClassificationInferenceResults(value.getValue(), classificationLabel(topClasses.v1().getValue(), classificationLabels), topClasses.v2(), - transformFeatureImportance(decodedFeatureImportance, classificationLabels), + transformFeatureImportanceClassification(decodedFeatureImportance, + value.getValue(), + classificationLabels, + classificationConfig.getPredictionFieldType()), config, value.getProbability(), value.getScore()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index ad05412134f..ac35f0cff4a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -188,14 +188,17 @@ public class TreeInferenceModel implements InferenceModel { return new ClassificationInferenceResults(classificationValue.getValue(), classificationLabel(classificationValue.getValue(), classificationLabels), topClasses.v2(), - InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels), + InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance, + classificationValue.getValue(), + classificationLabels, + classificationConfig.getPredictionFieldType()), config, classificationValue.getProbability(), classificationValue.getScore()); case REGRESSION: return new RegressionInferenceResults(value[0], config, - InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null)); + InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance)); default: throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java index 4fe3464f8e9..9f2df2b7512 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -12,8 +12,10 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; @@ -185,8 +187,17 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, - a -> new ClassImportance((String)a[0], (Importance)a[1])); - parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME); + a -> new ClassImportance(a[0], (Importance)a[1])); + parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return p.text(); + } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) { + return p.numberValue(); + } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) { + return p.booleanValue(); + } + throw new XContentParseException("Unsupported token [" + p.currentToken() + "]"); + }, CLASS_NAME, ObjectParser.ValueType.VALUE); parser.declareObject(ConstructingObjectParser.constructorArg(), ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER, IMPORTANCE); @@ -197,22 +208,22 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); } - public final String className; + public final Object className; public final Importance importance; public ClassImportance(StreamInput in) throws IOException { - this.className = in.readString(); + this.className = in.readGenericValue(); this.importance = new Importance(in); } - ClassImportance(String className, Importance importance) { + ClassImportance(Object className, Importance importance) { this.className = className; this.importance = importance; } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(className); + out.writeGenericValue(className); importance.writeTo(out); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java index e2faeaba8fe..6ecd7a8e212 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import java.io.IOException; @@ -154,10 +155,26 @@ public class InferenceDefinitionTests extends ESTestCase { ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config); assertThat(results.valueAsString(), equalTo("second")); - assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2")); - assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001)); - assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male")); - assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001)); + FeatureImportance featureImportance1 = results.getFeatureImportance().get(0); + assertThat(featureImportance1.getFeatureName(), equalTo("col2")); + assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001)); + for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) { + if (classImportance.getClassName().equals("second")) { + assertThat(classImportance.getImportance(), closeTo(0.944, 0.001)); + } else { + assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001)); + } + } + FeatureImportance featureImportance2 = results.getFeatureImportance().get(1); + assertThat(featureImportance2.getFeatureName(), equalTo("col1_male")); + assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001)); + for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) { + if (classImportance.getClassName().equals("second")) { + assertThat(classImportance.getImportance(), closeTo(0.199, 0.001)); + } else { + assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001)); + } + } } public static String getClassificationDefinition(boolean customPreprocessor) {