This commit addresses two issues:

- Per-class feature importance is now written out for binary classification (logistic regression).
- The `class_name` in per-class feature importance now matches what is written in the `top_classes` array.

Backport of https://github.com/elastic/elasticsearch/pull/61597

This commit is contained in:
parent 2858e1efc4
commit 8b33d8813a

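For orientation, here is a hedged illustration of the net effect, inferred from the diffs below rather than quoted from the PR (the field names are assumed to follow the usual inference-result layout, and the values are borrowed from the updated test at the end): a binary-classification result now carries per-class importances, and each `class_name` has the same type and label as the corresponding entry in `top_classes`.

    "feature_importance": [
      {
        "feature_name": "col2",
        "importance": 0.944,
        "classes": [
          { "class_name": "second", "importance": 0.944 },
          { "class_name": "first", "importance": -0.944 }
        ]
      }
    ]
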
@@ -39,6 +39,12 @@ public class FeatureImportance implements Writeable, ToXContentObject {
         return new FeatureImportance(featureName, importance, null);
     }
 
+    public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
+        return new FeatureImportance(featureName,
+            importance,
+            classImportance);
+    }
+
     public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
         return new FeatureImportance(featureName,
             classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),

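A hedged aside on the two factory methods above (a sketch with made-up values and a hypothetical class name, not code from the commit): `forBinaryClassification` takes the overall importance directly from the caller, while `forClassification` derives it by summing the absolute per-class importances.

    import java.util.Arrays;
    import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;

    class FeatureImportanceFactoriesSketch {
        static void demo() {
            // Binary: the overall importance is supplied by the caller (v[0]);
            // the two per-class entries mirror it with opposite signs.
            FeatureImportance binary = FeatureImportance.forBinaryClassification("col2", 0.944,
                Arrays.asList(
                    new FeatureImportance.ClassImportance("second", 0.944),
                    new FeatureImportance.ClassImportance("first", -0.944)));

            // Multi-class: the overall importance is computed as
            // sum(|class importance|) = 0.5 + 0.3 + 0.2 = 1.0.
            FeatureImportance multi = FeatureImportance.forClassification("col3",
                Arrays.asList(
                    new FeatureImportance.ClassImportance("a", 0.5),
                    new FeatureImportance.ClassImportance("b", -0.3),
                    new FeatureImportance.ClassImportance("c", -0.2)));
        }
    }
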
@@ -170,27 +176,27 @@ public class FeatureImportance implements Writeable, ToXContentObject {
         }
 
         private static Map<String, Double> toMap(List<ClassImportance> importances) {
-            return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
+            return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
         }
 
         public static ClassImportance fromXContent(XContentParser parser) {
             return PARSER.apply(parser, null);
         }
 
-        private final String className;
+        private final Object className;
         private final double importance;
 
-        public ClassImportance(String className, double importance) {
+        public ClassImportance(Object className, double importance) {
             this.className = className;
             this.importance = importance;
         }
 
         public ClassImportance(StreamInput in) throws IOException {
-            this.className = in.readString();
+            this.className = in.readGenericValue();
             this.importance = in.readDouble();
         }
 
-        public String getClassName() {
+        public Object getClassName() {
             return className;
         }

@@ -207,7 +213,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(className);
+            out.writeGenericValue(className);
             out.writeDouble(importance);
         }
 

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;

@@ -129,21 +130,46 @@ public final class InferenceHelpers {
         return originalFeatureImportance;
     }
 
-    public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
-                                                                     @Nullable List<String> classificationLabels) {
+    public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
         List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
+        featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
+        return importances;
+    }
+
+    public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
+                                                                                   final int predictedValue,
+                                                                                   @Nullable List<String> classificationLabels,
+                                                                                   @Nullable PredictionFieldType predictionFieldType) {
+        List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
+        final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
         featureImportance.forEach((k, v) -> {
-            // This indicates regression, or logistic regression
+            // This indicates logistic regression (binary classification)
+            // If the length > 1, we assume multi-class classification.
             if (v.length == 1) {
-                importances.add(FeatureImportance.forRegression(k, v[0]));
+                assert predictedValue == 1 || predictedValue == 0;
+                // If predicted value is `1`, then the other class is `0`
+                // If predicted value is `0`, then the other class is `1`
+                final int otherClass = 1 - predictedValue;
+                String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
+                String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
+                importances.add(FeatureImportance.forBinaryClassification(k,
+                    v[0],
+                    Arrays.asList(
+                        new FeatureImportance.ClassImportance(
+                            fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
+                            v[0]),
+                        new FeatureImportance.ClassImportance(
+                            fieldType.transformPredictedValue((double)otherClass, otherLabel),
+                            -v[0])
+                    )));
             } else {
                 List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
                 // If the classificationLabels exist, their length must match leaf_value length
                 assert classificationLabels == null || classificationLabels.size() == v.length;
                 for (int i = 0; i < v.length; i++) {
+                    String label = classificationLabels == null ? null : classificationLabels.get(i);
                     classImportance.add(new FeatureImportance.ClassImportance(
-                        classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
+                        fieldType.transformPredictedValue((double)i, label),
                         v[i]));
                 }
                 importances.add(FeatureImportance.forClassification(k, classImportance));

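A hedged usage sketch of the new classification transform (class name is hypothetical; the values are borrowed from the updated test near the end of this commit; passing null for the prediction field type falls back to STRING, per the default above):

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    import java.util.Map;
    import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
    import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;

    class TransformSketch {
        static void demo() {
            Map<String, double[]> decoded = Collections.singletonMap("col2", new double[] { 0.944 });
            // predictedValue == 1 selects classificationLabels.get(1), i.e. "second".
            List<FeatureImportance> importances = InferenceHelpers.transformFeatureImportanceClassification(
                decoded, 1, Arrays.asList("first", "second"), null);
            // Expected, per the binary branch above: feature "col2" with overall
            // importance 0.944; class "second" -> +0.944 (the predicted class) and
            // class "first" -> -0.944 (the mirrored other class).
        }
    }
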
@@ -43,7 +43,8 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
-import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;

@@ -154,14 +155,7 @@ public class EnsembleInferenceModel implements InferenceModel {
             RawInferenceResults inferenceResult = (RawInferenceResults) result;
             inferenceResults[i++] = inferenceResult.getValue();
             if (config.requestingImportance()) {
-                double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
-                assert modelFeatureImportance.length == featureInfluence.length;
-                for (int j = 0; j < modelFeatureImportance.length; j++) {
-                    if (featureInfluence[j] == null) {
-                        featureInfluence[j] = new double[modelFeatureImportance[j].length];
-                    }
-                    featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
-                }
+                addFeatureImportance(featureInfluence, inferenceResult);
             }
         }
         double[] processed = outputAggregator.processValues(inferenceResults);

@@ -176,18 +170,22 @@ public class EnsembleInferenceModel implements InferenceModel {
             InferenceResults result = model.infer(features, subModelInferenceConfig);
             assert result instanceof RawInferenceResults;
             RawInferenceResults inferenceResult = (RawInferenceResults) result;
-            double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
-            assert modelFeatureImportance.length == featureInfluence.length;
-            for (int j = 0; j < modelFeatureImportance.length; j++) {
-                if (featureInfluence[j] == null) {
-                    featureInfluence[j] = new double[modelFeatureImportance[j].length];
-                }
-                featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
-            }
+            addFeatureImportance(featureInfluence, inferenceResult);
         }
         return featureInfluence;
     }
 
+    private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
+        double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
+        assert modelFeatureImportance.length == featureInfluence.length;
+        for (int j = 0; j < modelFeatureImportance.length; j++) {
+            if (featureInfluence[j] == null) {
+                featureInfluence[j] = new double[modelFeatureImportance[j].length];
+            }
+            featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
+        }
+    }
+
     private InferenceResults buildResults(double[] processedInferences,
                                           double[][] featureImportance,
                                           Map<String, String> featureDecoderMap,

@@ -208,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
             case REGRESSION:
                 return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
                     config,
-                    transformFeatureImportance(decodedFeatureImportance, null));
+                    transformFeatureImportanceRegression(decodedFeatureImportance));
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
                 assert classificationWeights == null || processedInferences.length == classificationWeights.length;

@@ -220,10 +218,13 @@ public class EnsembleInferenceModel implements InferenceModel {
                     classificationConfig.getNumTopClasses(),
                     classificationConfig.getPredictionFieldType());
                 final InferenceHelpers.TopClassificationValue value = topClasses.v1();
-                return new ClassificationInferenceResults((double)value.getValue(),
+                return new ClassificationInferenceResults(value.getValue(),
                     classificationLabel(topClasses.v1().getValue(), classificationLabels),
                     topClasses.v2(),
-                    transformFeatureImportance(decodedFeatureImportance, classificationLabels),
+                    transformFeatureImportanceClassification(decodedFeatureImportance,
+                        value.getValue(),
+                        classificationLabels,
+                        classificationConfig.getPredictionFieldType()),
                     config,
                     value.getProbability(),
                     value.getScore());

@@ -188,14 +188,17 @@ public class TreeInferenceModel implements InferenceModel {
                 return new ClassificationInferenceResults(classificationValue.getValue(),
                     classificationLabel(classificationValue.getValue(), classificationLabels),
                     topClasses.v2(),
-                    InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
+                    InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
+                        classificationValue.getValue(),
+                        classificationLabels,
+                        classificationConfig.getPredictionFieldType()),
                     config,
                     classificationValue.getProbability(),
                     classificationValue.getScore());
             case REGRESSION:
                 return new RegressionInferenceResults(value[0],
                     config,
-                    InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null));
+                    InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
         }

@@ -12,8 +12,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;

@@ -185,8 +187,17 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
         private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
             ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
                 ignoreUnknownFields,
-                a -> new ClassImportance((String)a[0], (Importance)a[1]));
-            parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME);
+                a -> new ClassImportance(a[0], (Importance)a[1]));
+            parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
+                if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                    return p.text();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
+                    return p.numberValue();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
+                    return p.booleanValue();
+                }
+                throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
+            }, CLASS_NAME, ObjectParser.ValueType.VALUE);
             parser.declareObject(ConstructingObjectParser.constructorArg(),
                 ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
                 IMPORTANCE);

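The `declareField` above replaces `declareString` so that `class_name` can be a string, number, or boolean, matching the type `top_classes` reports for the dependent variable. The same token dispatch, extracted into a standalone helper for readability (a sketch; the class and method names are hypothetical):

    import java.io.IOException;
    import org.elasticsearch.common.xcontent.XContentParseException;
    import org.elasticsearch.common.xcontent.XContentParser;

    final class ClassNameParsingSketch {
        // Mirrors the declareField lambda above: accept whichever scalar type
        // the dependent variable used, instead of strings only.
        static Object readClassName(XContentParser p) throws IOException {
            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
                return p.text();          // e.g. "second" for keyword targets
            } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
                return p.numberValue();   // e.g. 0 or 1 for integer targets
            } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
                return p.booleanValue();  // e.g. true or false targets
            }
            throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
        }
    }
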
@@ -197,22 +208,22 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
         }
 
-        public final String className;
+        public final Object className;
         public final Importance importance;
 
         public ClassImportance(StreamInput in) throws IOException {
-            this.className = in.readString();
+            this.className = in.readGenericValue();
             this.importance = new Importance(in);
         }
 
-        ClassImportance(String className, Importance importance) {
+        ClassImportance(Object className, Importance importance) {
             this.className = className;
             this.importance = importance;
         }
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(className);
+            out.writeGenericValue(className);
             importance.writeTo(out);
         }
 

@@ -17,6 +17,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 
 import java.io.IOException;

@@ -154,10 +155,26 @@ public class InferenceDefinitionTests extends ESTestCase {
 
         ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
         assertThat(results.valueAsString(), equalTo("second"));
-        assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
-        assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
-        assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
-        assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
+        FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
+        assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
+        assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
+        for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
+            if (classImportance.getClassName().equals("second")) {
+                assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
+            } else {
+                assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
+            }
+        }
+        FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
+        assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
+        assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
+        for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
+            if (classImportance.getClassName().equals("second")) {
+                assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
+            } else {
+                assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001));
+            }
+        }
     }
 
     public static String getClassificationDefinition(boolean customPreprocessor) {