[ML] fix inference binary classification predication label and feature importance (#63688) (#63930)

When calculating feature importance, the leaf values directly correlate the value of the importance.

Consequently, positive leaf values -> positive feature importance

negative leaf values -> negative feature importance.

It follows that for binary classification, this is done such that the importance relates to the leaf values, which relate directly to the "probability of class 1".

So, the feature importance calculated is always for the importance as it relates to class 1.

The inverse is the importance as it relates to class 0.
This commit is contained in:
Benjamin Trent 2020-10-20 08:50:15 -04:00 committed by GitHub
parent b5448f07f3
commit eff7f06ca6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 69 additions and 39 deletions

View File

@ -139,7 +139,6 @@ public final class InferenceHelpers {
public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
Map<String, double[]> featureImportance,
final int predictedValue,
@Nullable List<String> classificationLabels,
@Nullable PredictionFieldType predictionFieldType) {
List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size());
@ -148,20 +147,20 @@ public final class InferenceHelpers {
// This indicates logistic regression (binary classification)
// If the length > 1, we assume multi-class classification.
if (v.length == 1) {
assert predictedValue == 1 || predictedValue == 0;
// If predicted value is `1`, then the other class is `0`
// If predicted value is `0`, then the other class is `1`
final int otherClass = 1 - predictedValue;
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
String zeroLabel = classificationLabels == null ? null : classificationLabels.get(0);
String oneLabel = classificationLabels == null ? null : classificationLabels.get(1);
// For feature importance, it is built off of the value in the leaves.
// These leaves indicate which direction the feature pulls the value
// The original importance is an indication of how it pushes or pulls the value towards or from `1`
// To get the importance for the `0` class, we simply invert it.
importances.add(new ClassificationFeatureImportance(k,
Arrays.asList(
new ClassificationFeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
v[0]),
fieldType.transformPredictedValue(0.0, zeroLabel),
-v[0]),
new ClassificationFeatureImportance.ClassImportance(
fieldType.transformPredictedValue((double)otherClass, otherLabel),
-v[0])
fieldType.transformPredictedValue(1.0, oneLabel),
v[0])
)));
} else {
List<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);

View File

@ -52,22 +52,25 @@ public enum PredictionFieldType implements Writeable {
case STRING:
return stringRep == null ? value.toString() : stringRep;
case BOOLEAN:
if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) {
throw new IllegalArgumentException(
"Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]");
if (isNumberQuickCheck(stringRep)) {
try {
// 1 is true, 0 is false
return Integer.parseInt(stringRep) == 1;
} catch (NumberFormatException nfe) {
// do nothing, allow fall through to final fromDouble
}
} else if (isBoolQuickCheck(stringRep)) { // if we start with t/f case insensitive, it indicates boolean string
return Boolean.parseBoolean(stringRep);
}
return areClose(value, 1.0D);
return fromDouble(value);
case NUMBER:
if (Strings.isNullOrEmpty(stringRep)) {
return value;
}
// Quick check to verify that the string rep is LIKELY a number
// Still handles the case where it throws and then returns the underlying value
if (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0))) {
if (isNumberQuickCheck(stringRep)) {
try {
return Long.parseLong(stringRep);
} catch (NumberFormatException nfe) {
return value;
// do nothing, allow fall through to final return
}
}
return value;
@ -76,7 +79,27 @@ public enum PredictionFieldType implements Writeable {
}
}
private static boolean fromDouble(double value) {
if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) {
throw new IllegalArgumentException(
"Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]");
}
return areClose(value, 1.0D);
}
private static boolean areClose(double value1, double value2) {
return Math.abs(value1 - value2) < EPS;
}
private static boolean isNumberQuickCheck(String stringRep) {
return Strings.isNullOrEmpty(stringRep) == false && (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0)));
}
private static boolean isBoolQuickCheck(String stringRep) {
if (Strings.isNullOrEmpty(stringRep)) {
return false;
}
char c = stringRep.charAt(0);
return 't' == c || 'T' == c || 'f' == c || 'F' == c;
}
}

View File

@ -213,7 +213,6 @@ public class EnsembleInferenceModel implements InferenceModel {
classificationLabel(topClasses.v1().getValue(), classificationLabels),
topClasses.v2(),
transformFeatureImportanceClassification(decodedFeatureImportance,
value.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config,

View File

@ -180,7 +180,6 @@ public class TreeInferenceModel implements InferenceModel {
classificationLabel(classificationValue.getValue(), classificationLabels),
topClasses.v2(),
InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
classificationValue.getValue(),
classificationLabels,
classificationConfig.getPredictionFieldType()),
config,

View File

@ -13,17 +13,27 @@ import static org.hamcrest.Matchers.nullValue;
public class PredictionFieldTypeTests extends ESTestCase {
private static final String NOT_BOOLEAN = "not_boolean";
public void testTransformPredictedValueBoolean() {
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)),
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : NOT_BOOLEAN),
is(nullValue()));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : randomAlphaOfLength(10)),
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : NOT_BOOLEAN),
is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : randomAlphaOfLength(10)),
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : NOT_BOOLEAN),
is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "1"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "0"), is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "TruE"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "fAlsE"), is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "0"), is(false));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "1"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "TruE"), is(true));
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "fAlse"), is(false));
expectThrows(IllegalArgumentException.class,
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : randomAlphaOfLength(10)));
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : NOT_BOOLEAN));
expectThrows(IllegalArgumentException.class,
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : randomAlphaOfLength(10)));
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : NOT_BOOLEAN));
}
public void testTransformPredictedValueString() {

View File

@ -165,7 +165,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
.stream()
.map(i -> ((SingleValueInferenceResults)i).valueAsString())
.collect(Collectors.toList()),
contains("not_to_be", "to_be"));
contains("no", "yes"));
// Get top classes
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true);
@ -174,14 +174,14 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
ClassificationInferenceResults classificationInferenceResults =
(ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("no"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("yes"));
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("no"));
// they should always be in order of Most probable to least
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
@ -192,7 +192,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes"));
}
public void testInferModelMultiClassModel() throws Exception {

View File

@ -114,13 +114,13 @@ public class LocalModelTests extends ESTestCase {
mock(CircuitBreaker.class));
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be"));
assertThat(result.valueAsString(), equalTo("no"));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
fields,
new ClassificationConfigUpdate(1, null, null, null, null));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("no"));
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
@ -169,11 +169,11 @@ public class LocalModelTests extends ESTestCase {
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
writeResult(result, document, "result_field", modelId);
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be"));
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("no"));
List<?> list = document.getFieldValue("result_field.top_classes", List.class);
assertThat(list.size(), equalTo(2));
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo("not_to_be"));
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo("to_be"));
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo("no"));
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo("yes"));
result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER));
@ -440,7 +440,7 @@ public class LocalModelTests extends ESTestCase {
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.build();
return Ensemble.builder()
.setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null)
.setClassificationLabels(includeLabels ? Arrays.asList("no", "yes") : null)
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))