[ML] binary classification per-class feature importance for model inference (#61597) (#61746)

This commit addresses two issues:

- per class feature importance is now written out for binary classification (logistic regression)
- The `class_name` in per class feature importance now matches what is written in the `top_classes` array.

backport of https://github.com/elastic/elasticsearch/pull/61597
This commit is contained in:
Benjamin Trent 2020-08-31 13:57:00 -04:00 committed by GitHub
parent 2858e1efc4
commit 8b33d8813a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 107 additions and 43 deletions

View File

@@ -39,6 +39,12 @@ public class FeatureImportance implements Writeable, ToXContentObject {
return new FeatureImportance(featureName, importance, null); return new FeatureImportance(featureName, importance, null);
} }
public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
importance,
classImportance);
}
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) { public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName, return new FeatureImportance(featureName,
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(), classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
@@ -170,27 +176,27 @@ public class FeatureImportance implements Writeable, ToXContentObject {
} }
private static Map<String, Double> toMap(List<ClassImportance> importances) { private static Map<String, Double> toMap(List<ClassImportance> importances) {
return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance)); return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
} }
public static ClassImportance fromXContent(XContentParser parser) { public static ClassImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private final String className; private final Object className;
private final double importance; private final double importance;
public ClassImportance(String className, double importance) { public ClassImportance(Object className, double importance) {
this.className = className; this.className = className;
this.importance = importance; this.importance = importance;
} }
public ClassImportance(StreamInput in) throws IOException { public ClassImportance(StreamInput in) throws IOException {
this.className = in.readString(); this.className = in.readGenericValue();
this.importance = in.readDouble(); this.importance = in.readDouble();
} }
public String getClassName() { public Object getClassName() {
return className; return className;
} }
@@ -207,7 +213,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeString(className); out.writeGenericValue(className);
out.writeDouble(importance); out.writeDouble(importance);
} }

View File

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
@@ -129,21 +130,46 @@ public final class InferenceHelpers {
return originalFeatureImportance; return originalFeatureImportance;
} }
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance, public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
@Nullable List<String> classificationLabels) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size()); List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
return importances;
}
public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
final int predictedValue,
@Nullable List<String> classificationLabels,
@Nullable PredictionFieldType predictionFieldType) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
featureImportance.forEach((k, v) -> { featureImportance.forEach((k, v) -> {
// This indicates regression, or logistic regression // This indicates logistic regression (binary classification)
// If the length > 1, we assume multi-class classification. // If the length > 1, we assume multi-class classification.
if (v.length == 1) { if (v.length == 1) {
importances.add(FeatureImportance.forRegression(k, v[0])); assert predictedValue == 1 || predictedValue == 0;
// If predicted value is `1`, then the other class is `0`
// If predicted value is `0`, then the other class is `1`
final int otherClass = 1 - predictedValue;
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
importances.add(FeatureImportance.forBinaryClassification(k,
v[0],
Arrays.asList(
new FeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
v[0]),
new FeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)otherClass, otherLabel),
-v[0])
)));
} else { } else {
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length); List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
// If the classificationLabels exist, their length must match leaf_value length // If the classificationLabels exist, their length must match leaf_value length
assert classificationLabels == null || classificationLabels.size() == v.length; assert classificationLabels == null || classificationLabels.size() == v.length;
for (int i = 0; i < v.length; i++) { for (int i = 0; i < v.length; i++) {
String label = classificationLabels == null ? null : classificationLabels.get(i);
classImportance.add(new FeatureImportance.ClassImportance( classImportance.add(new FeatureImportance.ClassImportance(
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), fieldType.transformPredictedValue((double)i, label),
v[i])); v[i]));
} }
importances.add(FeatureImportance.forClassification(k, classImportance)); importances.add(FeatureImportance.forClassification(k, classImportance));

View File

@@ -43,7 +43,8 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
@@ -154,14 +155,7 @@ public class EnsembleInferenceModel implements InferenceModel {
RawInferenceResults inferenceResult = (RawInferenceResults) result; RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue(); inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) { if (config.requestingImportance()) {
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance(); addFeatureImportance(featureInfluence, inferenceResult);
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
} }
} }
double[] processed = outputAggregator.processValues(inferenceResults); double[] processed = outputAggregator.processValues(inferenceResults);
@@ -176,6 +170,12 @@ public class EnsembleInferenceModel implements InferenceModel {
InferenceResults result = model.infer(features, subModelInferenceConfig); InferenceResults result = model.infer(features, subModelInferenceConfig);
assert result instanceof RawInferenceResults; assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result; RawInferenceResults inferenceResult = (RawInferenceResults) result;
addFeatureImportance(featureInfluence, inferenceResult);
}
return featureInfluence;
}
private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance(); double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length; assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) { for (int j = 0; j < modelFeatureImportance.length; j++) {
@@ -185,8 +185,6 @@ public class EnsembleInferenceModel implements InferenceModel {
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]); featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
} }
} }
return featureInfluence;
}
private InferenceResults buildResults(double[] processedInferences, private InferenceResults buildResults(double[] processedInferences,
double[][] featureImportance, double[][] featureImportance,
@@ -208,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
case REGRESSION: case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config, config,
transformFeatureImportance(decodedFeatureImportance, null)); transformFeatureImportanceRegression(decodedFeatureImportance));
case CLASSIFICATION: case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config; ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length; assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@@ -220,10 +218,13 @@ public class EnsembleInferenceModel implements InferenceModel {
classificationConfig.getNumTopClasses(), classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType()); classificationConfig.getPredictionFieldType());
final InferenceHelpers.TopClassificationValue value = topClasses.v1(); final InferenceHelpers.TopClassificationValue value = topClasses.v1();
return new ClassificationInferenceResults((double)value.getValue(), return new ClassificationInferenceResults(value.getValue(),
classificationLabel(topClasses.v1().getValue(), classificationLabels), classificationLabel(topClasses.v1().getValue(), classificationLabels),
topClasses.v2(), topClasses.v2(),
transformFeatureImportance(decodedFeatureImportance, classificationLabels), transformFeatureImportanceClassification(decodedFeatureImportance,
value.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config, config,
value.getProbability(), value.getProbability(),
value.getScore()); value.getScore());

View File

@@ -188,14 +188,17 @@ public class TreeInferenceModel implements InferenceModel {
return new ClassificationInferenceResults(classificationValue.getValue(), return new ClassificationInferenceResults(classificationValue.getValue(),
classificationLabel(classificationValue.getValue(), classificationLabels), classificationLabel(classificationValue.getValue(), classificationLabels),
topClasses.v2(), topClasses.v2(),
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels), InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
classificationValue.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config, config,
classificationValue.getProbability(), classificationValue.getProbability(),
classificationValue.getScore()); classificationValue.getScore());
case REGRESSION: case REGRESSION:
return new RegressionInferenceResults(value[0], return new RegressionInferenceResults(value[0],
config, config,
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null)); InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
default: default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
} }

View File

@@ -12,8 +12,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
@@ -185,8 +187,17 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) { private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME, ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields, ignoreUnknownFields,
a -> new ClassImportance((String)a[0], (Importance)a[1])); a -> new ClassImportance(a[0], (Importance)a[1]));
parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME); parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
parser.declareObject(ConstructingObjectParser.constructorArg(), parser.declareObject(ConstructingObjectParser.constructorArg(),
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER, ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
IMPORTANCE); IMPORTANCE);
@@ -197,22 +208,22 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
} }
public final String className; public final Object className;
public final Importance importance; public final Importance importance;
public ClassImportance(StreamInput in) throws IOException { public ClassImportance(StreamInput in) throws IOException {
this.className = in.readString(); this.className = in.readGenericValue();
this.importance = new Importance(in); this.importance = new Importance(in);
} }
ClassImportance(String className, Importance importance) { ClassImportance(Object className, Importance importance) {
this.className = className; this.className = className;
this.importance = importance; this.importance = importance;
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeString(className); out.writeGenericValue(className);
importance.writeTo(out); importance.writeTo(out);
} }

View File

@@ -17,6 +17,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import java.io.IOException; import java.io.IOException;
@@ -154,10 +155,26 @@ public class InferenceDefinitionTests extends ESTestCase {
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config); ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
assertThat(results.valueAsString(), equalTo("second")); assertThat(results.valueAsString(), equalTo("second"));
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2")); FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001)); assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male")); assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001)); for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
} else {
assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
}
}
FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
} else {
assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001));
}
}
} }
public static String getClassificationDefinition(boolean customPreprocessor) { public static String getClassificationDefinition(boolean customPreprocessor) {