[ML] binary classification per-class feature importance for model inference (#61597) (#61746)

This commit addresses two issues:

- per class feature importance is now written out for binary classification (logistic regression)
- The `class_name` in per class feature importance now matches what is written in the `top_classes` array.

backport of https://github.com/elastic/elasticsearch/pull/61597
This commit is contained in:
Benjamin Trent 2020-08-31 13:57:00 -04:00 committed by GitHub
parent 2858e1efc4
commit 8b33d8813a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 107 additions and 43 deletions

View File

@ -39,6 +39,12 @@ public class FeatureImportance implements Writeable, ToXContentObject {
return new FeatureImportance(featureName, importance, null);
}
// Factory for binary-classification (logistic regression) feature importance:
// carries the overall importance value plus the per-class importance entries,
// unlike forClassification, which derives the total from the class list.
public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
importance,
classImportance);
}
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
return new FeatureImportance(featureName,
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
@ -170,27 +176,27 @@ public class FeatureImportance implements Writeable, ToXContentObject {
}
private static Map<String, Double> toMap(List<ClassImportance> importances) {
return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
}
// Parses a ClassImportance object from XContent via the class's static PARSER;
// the null context is unused by the parser.
public static ClassImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final String className;
private final Object className;
private final double importance;
public ClassImportance(String className, double importance) {
public ClassImportance(Object className, double importance) {
this.className = className;
this.importance = importance;
}
public ClassImportance(StreamInput in) throws IOException {
this.className = in.readString();
this.className = in.readGenericValue();
this.importance = in.readDouble();
}
public String getClassName() {
public Object getClassName() {
return className;
}
@ -207,7 +213,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(className);
out.writeGenericValue(className);
out.writeDouble(importance);
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
@ -129,21 +130,46 @@ public final class InferenceHelpers {
return originalFeatureImportance;
}
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
@Nullable List<String> classificationLabels) {
// Converts raw per-feature importance arrays into FeatureImportance results for
// regression. Each value array is expected to hold a single element (v[0]);
// classification (multi-element arrays) is handled separately by
// transformFeatureImportanceClassification.
public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
return importances;
}
public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
final int predictedValue,
@Nullable List<String> classificationLabels,
@Nullable PredictionFieldType predictionFieldType) {
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
featureImportance.forEach((k, v) -> {
// This indicates regression, or logistic regression
// This indicates logistic regression (binary classification)
// If the length > 1, we assume multi-class classification.
if (v.length == 1) {
importances.add(FeatureImportance.forRegression(k, v[0]));
assert predictedValue == 1 || predictedValue == 0;
// If predicted value is `1`, then the other class is `0`
// If predicted value is `0`, then the other class is `1`
final int otherClass = 1 - predictedValue;
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
importances.add(FeatureImportance.forBinaryClassification(k,
v[0],
Arrays.asList(
new FeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
v[0]),
new FeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)otherClass, otherLabel),
-v[0])
)));
} else {
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
// If the classificationLabels exist, their length must match leaf_value length
assert classificationLabels == null || classificationLabels.size() == v.length;
for (int i = 0; i < v.length; i++) {
String label = classificationLabels == null ? null : classificationLabels.get(i);
classImportance.add(new FeatureImportance.ClassImportance(
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
fieldType.transformPredictedValue((double)i, label),
v[i]));
}
importances.add(FeatureImportance.forClassification(k, classImportance));

View File

@ -43,7 +43,8 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
@ -154,14 +155,7 @@ public class EnsembleInferenceModel implements InferenceModel {
RawInferenceResults inferenceResult = (RawInferenceResults) result;
inferenceResults[i++] = inferenceResult.getValue();
if (config.requestingImportance()) {
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
addFeatureImportance(featureInfluence, inferenceResult);
}
}
double[] processed = outputAggregator.processValues(inferenceResults);
@ -176,18 +170,22 @@ public class EnsembleInferenceModel implements InferenceModel {
InferenceResults result = model.infer(features, subModelInferenceConfig);
assert result instanceof RawInferenceResults;
RawInferenceResults inferenceResult = (RawInferenceResults) result;
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
addFeatureImportance(featureInfluence, inferenceResult);
}
return featureInfluence;
}
// Accumulates one sub-model's per-feature importance arrays into the running
// ensemble totals. Extracted so both infer paths share the same aggregation
// logic instead of duplicating this loop.
private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
// Outer length is per-feature and must match the accumulator's length.
assert modelFeatureImportance.length == featureInfluence.length;
for (int j = 0; j < modelFeatureImportance.length; j++) {
// Lazily size each feature's slot from the first model that reports it.
if (featureInfluence[j] == null) {
featureInfluence[j] = new double[modelFeatureImportance[j].length];
}
// Element-wise sum of this model's contribution into the running total.
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
}
}
private InferenceResults buildResults(double[] processedInferences,
double[][] featureImportance,
Map<String, String> featureDecoderMap,
@ -208,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
config,
transformFeatureImportance(decodedFeatureImportance, null));
transformFeatureImportanceRegression(decodedFeatureImportance));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@ -220,10 +218,13 @@ public class EnsembleInferenceModel implements InferenceModel {
classificationConfig.getNumTopClasses(),
classificationConfig.getPredictionFieldType());
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
return new ClassificationInferenceResults((double)value.getValue(),
return new ClassificationInferenceResults(value.getValue(),
classificationLabel(topClasses.v1().getValue(), classificationLabels),
topClasses.v2(),
transformFeatureImportance(decodedFeatureImportance, classificationLabels),
transformFeatureImportanceClassification(decodedFeatureImportance,
value.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config,
value.getProbability(),
value.getScore());

View File

@ -188,14 +188,17 @@ public class TreeInferenceModel implements InferenceModel {
return new ClassificationInferenceResults(classificationValue.getValue(),
classificationLabel(classificationValue.getValue(), classificationLabels),
topClasses.v2(),
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
classificationValue.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config,
classificationValue.getProbability(),
classificationValue.getScore());
case REGRESSION:
return new RegressionInferenceResults(value[0],
config,
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null));
InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}

View File

@ -12,8 +12,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
@ -185,8 +187,17 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
ignoreUnknownFields,
a -> new ClassImportance((String)a[0], (Importance)a[1]));
parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME);
a -> new ClassImportance(a[0], (Importance)a[1]));
parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
parser.declareObject(ConstructingObjectParser.constructorArg(),
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
IMPORTANCE);
@ -197,22 +208,22 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
public final String className;
public final Object className;
public final Importance importance;
public ClassImportance(StreamInput in) throws IOException {
this.className = in.readString();
this.className = in.readGenericValue();
this.importance = new Importance(in);
}
ClassImportance(String className, Importance importance) {
ClassImportance(Object className, Importance importance) {
this.className = className;
this.importance = importance;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(className);
out.writeGenericValue(className);
importance.writeTo(out);
}

View File

@ -17,6 +17,7 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import java.io.IOException;
@ -154,10 +155,26 @@ public class InferenceDefinitionTests extends ESTestCase {
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
assertThat(results.valueAsString(), equalTo("second"));
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
} else {
assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
}
}
FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
if (classImportance.getClassName().equals("second")) {
assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
} else {
assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001));
}
}
}
public static String getClassificationDefinition(boolean customPreprocessor) {