[ML] fix inference binary classification predication label and feature importance (#63688) (#63930)

When calculating feature importance, the leaf values directly correlate the value of the importance.

Consequently, positive leaf values -> positive feature importance

negative leaf values -> negative feature importance.

It follows that for binary classification, this is done such that the importance relates to the leaf values, which relate directly to the "probability of class 1".

So, the feature importance calculated is always for the importance as it relates to class 1.

The inverse is the importance as it relates to class 0.
This commit is contained in:
Benjamin Trent 2020-10-20 08:50:15 -04:00 committed by GitHub
parent b5448f07f3
commit eff7f06ca6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 69 additions and 39 deletions

View File

@ -139,7 +139,6 @@ public final class InferenceHelpers {
public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification( public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
Map<String, double[]> featureImportance, Map<String, double[]> featureImportance,
final int predictedValue,
@Nullable List<String> classificationLabels, @Nullable List<String> classificationLabels,
@Nullable PredictionFieldType predictionFieldType) { @Nullable PredictionFieldType predictionFieldType) {
List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size()); List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size());
@ -148,20 +147,20 @@ public final class InferenceHelpers {
// This indicates logistic regression (binary classification) // This indicates logistic regression (binary classification)
// If the length > 1, we assume multi-class classification. // If the length > 1, we assume multi-class classification.
if (v.length == 1) { if (v.length == 1) {
assert predictedValue == 1 || predictedValue == 0; String zeroLabel = classificationLabels == null ? null : classificationLabels.get(0);
// If predicted value is `1`, then the other class is `0` String oneLabel = classificationLabels == null ? null : classificationLabels.get(1);
// If predicted value is `0`, then the other class is `1` // For feature importance, it is built off of the value in the leaves.
final int otherClass = 1 - predictedValue; // These leaves indicate which direction the feature pulls the value
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue); // The original importance is an indication of how it pushes or pulls the value towards or from `1`
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass); // To get the importance for the `0` class, we simply invert it.
importances.add(new ClassificationFeatureImportance(k, importances.add(new ClassificationFeatureImportance(k,
Arrays.asList( Arrays.asList(
new ClassificationFeatureImportance.ClassImportance( new ClassificationFeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)predictedValue, predictedLabel), fieldType.transformPredictedValue(0.0, zeroLabel),
v[0]), -v[0]),
new ClassificationFeatureImportance.ClassImportance( new ClassificationFeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)otherClass, otherLabel), fieldType.transformPredictedValue(1.0, oneLabel),
-v[0]) v[0])
))); )));
} else { } else {
List<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length); List<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);

View File

@ -52,22 +52,25 @@ public enum PredictionFieldType implements Writeable {
case STRING: case STRING:
return stringRep == null ? value.toString() : stringRep; return stringRep == null ? value.toString() : stringRep;
case BOOLEAN: case BOOLEAN:
if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) { if (isNumberQuickCheck(stringRep)) {
throw new IllegalArgumentException( try {
"Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]"); // 1 is true, 0 is false
return Integer.parseInt(stringRep) == 1;
} catch (NumberFormatException nfe) {
// do nothing, allow fall through to final fromDouble
} }
return areClose(value, 1.0D); } else if (isBoolQuickCheck(stringRep)) { // if we start with t/f case insensitive, it indicates boolean string
return Boolean.parseBoolean(stringRep);
}
return fromDouble(value);
case NUMBER: case NUMBER:
if (Strings.isNullOrEmpty(stringRep)) {
return value;
}
// Quick check to verify that the string rep is LIKELY a number // Quick check to verify that the string rep is LIKELY a number
// Still handles the case where it throws and then returns the underlying value // Still handles the case where it throws and then returns the underlying value
if (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0))) { if (isNumberQuickCheck(stringRep)) {
try { try {
return Long.parseLong(stringRep); return Long.parseLong(stringRep);
} catch (NumberFormatException nfe) { } catch (NumberFormatException nfe) {
return value; // do nothing, allow fall through to final return
} }
} }
return value; return value;
@ -76,7 +79,27 @@ public enum PredictionFieldType implements Writeable {
} }
} }
private static boolean fromDouble(double value) {
if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) {
throw new IllegalArgumentException(
"Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]");
}
return areClose(value, 1.0D);
}
private static boolean areClose(double value1, double value2) { private static boolean areClose(double value1, double value2) {
return Math.abs(value1 - value2) < EPS; return Math.abs(value1 - value2) < EPS;
} }
private static boolean isNumberQuickCheck(String stringRep) {
return Strings.isNullOrEmpty(stringRep) == false && (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0)));
}
private static boolean isBoolQuickCheck(String stringRep) {
if (Strings.isNullOrEmpty(stringRep)) {
return false;
}
char c = stringRep.charAt(0);
return 't' == c || 'T' == c || 'f' == c || 'F' == c;
}
} }

View File

@ -213,7 +213,6 @@ public class EnsembleInferenceModel implements InferenceModel {
classificationLabel(topClasses.v1().getValue(), classificationLabels), classificationLabel(topClasses.v1().getValue(), classificationLabels),
topClasses.v2(), topClasses.v2(),
transformFeatureImportanceClassification(decodedFeatureImportance, transformFeatureImportanceClassification(decodedFeatureImportance,
value.getValue(),
classificationLabels, classificationLabels,
classificationConfig.getPredictionFieldType()), classificationConfig.getPredictionFieldType()),
config, config,

View File

@ -180,7 +180,6 @@ public class TreeInferenceModel implements InferenceModel {
classificationLabel(classificationValue.getValue(), classificationLabels), classificationLabel(classificationValue.getValue(), classificationLabels),
topClasses.v2(), topClasses.v2(),
InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance, InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
classificationValue.getValue(),
classificationLabels, classificationLabels,
classificationConfig.getPredictionFieldType()), classificationConfig.getPredictionFieldType()),
config, config,

View File

@ -13,17 +13,27 @@ import static org.hamcrest.Matchers.nullValue;
public class PredictionFieldTypeTests extends ESTestCase { public class PredictionFieldTypeTests extends ESTestCase {
private static final String NOT_BOOLEAN = "not_boolean";
public void testTransformPredictedValueBoolean() { public void testTransformPredictedValueBoolean() {
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)), assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : NOT_BOOLEAN),
is(nullValue())); is(nullValue()));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : randomAlphaOfLength(10)), assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : NOT_BOOLEAN),
is(true)); is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : randomAlphaOfLength(10)), assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : NOT_BOOLEAN),
is(false)); is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "1"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "0"), is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "TruE"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "fAlsE"), is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "0"), is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "1"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "TruE"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "fAlse"), is(false));
expectThrows(IllegalArgumentException.class, expectThrows(IllegalArgumentException.class,
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : randomAlphaOfLength(10))); () -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : NOT_BOOLEAN));
expectThrows(IllegalArgumentException.class, expectThrows(IllegalArgumentException.class,
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : randomAlphaOfLength(10))); () -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : NOT_BOOLEAN));
} }
public void testTransformPredictedValueString() { public void testTransformPredictedValueString() {

View File

@ -165,7 +165,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.stream() .stream()
.map(i -> ((SingleValueInferenceResults)i).valueAsString()) .map(i -> ((SingleValueInferenceResults)i).valueAsString())
.collect(Collectors.toList()), .collect(Collectors.toList()),
contains("not_to_be", "to_be")); contains("no", "yes"));
// Get top classes // Get top classes
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true); request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true);
@ -174,14 +174,14 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
ClassificationInferenceResults classificationInferenceResults = ClassificationInferenceResults classificationInferenceResults =
(ClassificationInferenceResults)response.getInferenceResults().get(0); (ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("no"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be")); assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("yes"));
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be")); assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("no"));
// they should always be in order of Most probable to least // they should always be in order of Most probable to least
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
@ -192,7 +192,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses(), hasSize(1)); assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes"));
} }
public void testInferModelMultiClassModel() throws Exception { public void testInferModelMultiClassModel() throws Exception {

View File

@ -114,13 +114,13 @@ public class LocalModelTests extends ESTestCase {
mock(CircuitBreaker.class)); mock(CircuitBreaker.class));
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0)); assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be")); assertThat(result.valueAsString(), equalTo("no"));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields, fields,
new ClassificationConfigUpdate(1, null, null, null, null)); new ClassificationConfigUpdate(1, null, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("no"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L)); assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, classificationResult = (ClassificationInferenceResults)getSingleValue(model,
@ -169,11 +169,11 @@ public class LocalModelTests extends ESTestCase {
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
writeResult(result, document, "result_field", modelId); writeResult(result, document, "result_field", modelId);
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be")); assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("no"));
List<?> list = document.getFieldValue("result_field.top_classes", List.class); List<?> list = document.getFieldValue("result_field.top_classes", List.class);
assertThat(list.size(), equalTo(2)); assertThat(list.size(), equalTo(2));
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo("not_to_be")); assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo("no"));
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo("to_be")); assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo("yes"));
result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER)); result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER));
@ -440,7 +440,7 @@ public class LocalModelTests extends ESTestCase {
.addNode(TreeNode.builder(2).setLeafValue(0.0)) .addNode(TreeNode.builder(2).setLeafValue(0.0))
.build(); .build();
return Ensemble.builder() return Ensemble.builder()
.setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null) .setClassificationLabels(includeLabels ? Arrays.asList("no", "yes") : null)
.setTargetType(TargetType.CLASSIFICATION) .setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames) .setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3)) .setTrainedModels(Arrays.asList(tree1, tree2, tree3))