When calculating feature importance, the leaf values correlate directly with the value of the importance. Consequently, positive leaf values -> positive feature importance, and negative leaf values -> negative feature importance. For binary classification, the leaf values relate directly to the "probability of class 1", so the feature importance calculated is always the importance as it relates to class 1. The importance as it relates to class 0 is the inverse (the negated value).
This commit is contained in:
parent
b5448f07f3
commit
eff7f06ca6
|
@ -139,7 +139,6 @@ public final class InferenceHelpers {
|
|||
|
||||
public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
|
||||
Map<String, double[]> featureImportance,
|
||||
final int predictedValue,
|
||||
@Nullable List<String> classificationLabels,
|
||||
@Nullable PredictionFieldType predictionFieldType) {
|
||||
List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size());
|
||||
|
@ -148,20 +147,20 @@ public final class InferenceHelpers {
|
|||
// This indicates logistic regression (binary classification)
|
||||
// If the length > 1, we assume multi-class classification.
|
||||
if (v.length == 1) {
|
||||
assert predictedValue == 1 || predictedValue == 0;
|
||||
// If predicted value is `1`, then the other class is `0`
|
||||
// If predicted value is `0`, then the other class is `1`
|
||||
final int otherClass = 1 - predictedValue;
|
||||
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
|
||||
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
|
||||
String zeroLabel = classificationLabels == null ? null : classificationLabels.get(0);
|
||||
String oneLabel = classificationLabels == null ? null : classificationLabels.get(1);
|
||||
// For feature importance, it is built off of the value in the leaves.
|
||||
// These leaves indicate which direction the feature pulls the value
|
||||
// The original importance is an indication of how it pushes or pulls the value towards or from `1`
|
||||
// To get the importance for the `0` class, we simply invert it.
|
||||
importances.add(new ClassificationFeatureImportance(k,
|
||||
Arrays.asList(
|
||||
new ClassificationFeatureImportance.ClassImportance(
|
||||
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
|
||||
v[0]),
|
||||
fieldType.transformPredictedValue(0.0, zeroLabel),
|
||||
-v[0]),
|
||||
new ClassificationFeatureImportance.ClassImportance(
|
||||
fieldType.transformPredictedValue((double)otherClass, otherLabel),
|
||||
-v[0])
|
||||
fieldType.transformPredictedValue(1.0, oneLabel),
|
||||
v[0])
|
||||
)));
|
||||
} else {
|
||||
List<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
|
||||
|
|
|
@ -52,22 +52,25 @@ public enum PredictionFieldType implements Writeable {
|
|||
case STRING:
|
||||
return stringRep == null ? value.toString() : stringRep;
|
||||
case BOOLEAN:
|
||||
if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) {
|
||||
throw new IllegalArgumentException(
|
||||
"Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]");
|
||||
if (isNumberQuickCheck(stringRep)) {
|
||||
try {
|
||||
// 1 is true, 0 is false
|
||||
return Integer.parseInt(stringRep) == 1;
|
||||
} catch (NumberFormatException nfe) {
|
||||
// do nothing, allow fall through to final fromDouble
|
||||
}
|
||||
} else if (isBoolQuickCheck(stringRep)) { // if we start with t/f case insensitive, it indicates boolean string
|
||||
return Boolean.parseBoolean(stringRep);
|
||||
}
|
||||
return areClose(value, 1.0D);
|
||||
return fromDouble(value);
|
||||
case NUMBER:
|
||||
if (Strings.isNullOrEmpty(stringRep)) {
|
||||
return value;
|
||||
}
|
||||
// Quick check to verify that the string rep is LIKELY a number
|
||||
// Still handles the case where it throws and then returns the underlying value
|
||||
if (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0))) {
|
||||
if (isNumberQuickCheck(stringRep)) {
|
||||
try {
|
||||
return Long.parseLong(stringRep);
|
||||
} catch (NumberFormatException nfe) {
|
||||
return value;
|
||||
// do nothing, allow fall through to final return
|
||||
}
|
||||
}
|
||||
return value;
|
||||
|
@ -76,7 +79,27 @@ public enum PredictionFieldType implements Writeable {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Converts a numeric prediction into a boolean.
 *
 * @param value the numeric value to convert; must be close (per {@code areClose}) to 0.0 or 1.0
 * @return {@code true} when {@code value} is close to 1.0, {@code false} when it is close to 0.0
 * @throws IllegalArgumentException if {@code value} is close to neither 0.0 nor 1.0
 */
private static boolean fromDouble(double value) {
|
||||
// Reject anything that is not (approximately) 0.0 or 1.0 before converting.
if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) {
|
||||
throw new IllegalArgumentException(
|
||||
"Cannot transform numbers other than 0.0 or 1.0 to boolean. Provided number [" + value + "]");
|
||||
}
|
||||
// At this point the value is close to either 0.0 or 1.0, so this check alone decides the result.
return areClose(value, 1.0D);
|
||||
}
|
||||
|
||||
/**
 * Approximate equality check for two doubles.
 *
 * @return {@code true} when the absolute difference between the two values is strictly less than {@code EPS}
 */
// NOTE(review): EPS is declared outside this view; its magnitude is not visible here.
private static boolean areClose(double value1, double value2) {
|
||||
return Math.abs(value1 - value2) < EPS;
|
||||
}
|
||||
|
||||
/**
 * Cheap pre-check for whether a string representation is LIKELY a number.
 *
 * Only the first character is inspected (it must be {@code '-'} or a digit), so a {@code true}
 * result does not guarantee that the string actually parses as a number.
 *
 * @param stringRep the candidate string; strings rejected by {@code Strings.isNullOrEmpty} yield {@code false}
 * @return {@code true} when the string passes the emptiness check and starts with {@code '-'} or a digit
 */
private static boolean isNumberQuickCheck(String stringRep) {
|
||||
return Strings.isNullOrEmpty(stringRep) == false && (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0)));
|
||||
}
|
||||
|
||||
/**
 * Cheap pre-check for whether a string representation looks like a boolean literal.
 *
 * Only the first character is inspected: the result is {@code true} when it is
 * {@code 't'}, {@code 'T'}, {@code 'f'} or {@code 'F'}. The rest of the string is not examined.
 *
 * @param stringRep the candidate string; strings rejected by {@code Strings.isNullOrEmpty} yield {@code false}
 * @return {@code true} when the first character is a case-insensitive {@code t} or {@code f}
 */
private static boolean isBoolQuickCheck(String stringRep) {
|
||||
if (Strings.isNullOrEmpty(stringRep)) {
|
||||
return false;
|
||||
}
|
||||
char c = stringRep.charAt(0);
|
||||
return 't' == c || 'T' == c || 'f' == c || 'F' == c;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -213,7 +213,6 @@ public class EnsembleInferenceModel implements InferenceModel {
|
|||
classificationLabel(topClasses.v1().getValue(), classificationLabels),
|
||||
topClasses.v2(),
|
||||
transformFeatureImportanceClassification(decodedFeatureImportance,
|
||||
value.getValue(),
|
||||
classificationLabels,
|
||||
classificationConfig.getPredictionFieldType()),
|
||||
config,
|
||||
|
|
|
@ -180,7 +180,6 @@ public class TreeInferenceModel implements InferenceModel {
|
|||
classificationLabel(classificationValue.getValue(), classificationLabels),
|
||||
topClasses.v2(),
|
||||
InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
|
||||
classificationValue.getValue(),
|
||||
classificationLabels,
|
||||
classificationConfig.getPredictionFieldType()),
|
||||
config,
|
||||
|
|
|
@ -13,17 +13,27 @@ import static org.hamcrest.Matchers.nullValue;
|
|||
|
||||
public class PredictionFieldTypeTests extends ESTestCase {
|
||||
|
||||
private static final String NOT_BOOLEAN = "not_boolean";
|
||||
|
||||
public void testTransformPredictedValueBoolean() {
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)),
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : NOT_BOOLEAN),
|
||||
is(nullValue()));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : randomAlphaOfLength(10)),
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : NOT_BOOLEAN),
|
||||
is(true));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : randomAlphaOfLength(10)),
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : NOT_BOOLEAN),
|
||||
is(false));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "1"), is(true));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "0"), is(false));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "TruE"), is(true));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "fAlsE"), is(false));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "0"), is(false));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "1"), is(true));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "TruE"), is(true));
|
||||
assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "fAlse"), is(false));
|
||||
expectThrows(IllegalArgumentException.class,
|
||||
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : randomAlphaOfLength(10)));
|
||||
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : NOT_BOOLEAN));
|
||||
expectThrows(IllegalArgumentException.class,
|
||||
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : randomAlphaOfLength(10)));
|
||||
() -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : NOT_BOOLEAN));
|
||||
}
|
||||
|
||||
public void testTransformPredictedValueString() {
|
||||
|
|
|
@ -165,7 +165,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
.stream()
|
||||
.map(i -> ((SingleValueInferenceResults)i).valueAsString())
|
||||
.collect(Collectors.toList()),
|
||||
contains("not_to_be", "to_be"));
|
||||
contains("no", "yes"));
|
||||
|
||||
// Get top classes
|
||||
request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true);
|
||||
|
@ -174,14 +174,14 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
ClassificationInferenceResults classificationInferenceResults =
|
||||
(ClassificationInferenceResults)response.getInferenceResults().get(0);
|
||||
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("no"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("yes"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
|
||||
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
|
||||
|
||||
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("no"));
|
||||
// they should always be in order of Most probable to least
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
|
||||
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
|
||||
|
@ -192,7 +192,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
|
|||
|
||||
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
|
||||
assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
|
||||
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes"));
|
||||
}
|
||||
|
||||
public void testInferModelMultiClassModel() throws Exception {
|
||||
|
|
|
@ -114,13 +114,13 @@ public class LocalModelTests extends ESTestCase {
|
|||
mock(CircuitBreaker.class));
|
||||
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
|
||||
assertThat(result.value(), equalTo(0.0));
|
||||
assertThat(result.valueAsString(), equalTo("not_to_be"));
|
||||
assertThat(result.valueAsString(), equalTo("no"));
|
||||
|
||||
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
|
||||
fields,
|
||||
new ClassificationConfigUpdate(1, null, null, null, null));
|
||||
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
|
||||
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
|
||||
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("no"));
|
||||
assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L));
|
||||
|
||||
classificationResult = (ClassificationInferenceResults)getSingleValue(model,
|
||||
|
@ -169,11 +169,11 @@ public class LocalModelTests extends ESTestCase {
|
|||
|
||||
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
|
||||
writeResult(result, document, "result_field", modelId);
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be"));
|
||||
assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("no"));
|
||||
List<?> list = document.getFieldValue("result_field.top_classes", List.class);
|
||||
assertThat(list.size(), equalTo(2));
|
||||
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo("not_to_be"));
|
||||
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo("to_be"));
|
||||
assertThat(((Map<String, Object>)list.get(0)).get("class_name"), equalTo("no"));
|
||||
assertThat(((Map<String, Object>)list.get(1)).get("class_name"), equalTo("yes"));
|
||||
|
||||
result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER));
|
||||
|
||||
|
@ -440,7 +440,7 @@ public class LocalModelTests extends ESTestCase {
|
|||
.addNode(TreeNode.builder(2).setLeafValue(0.0))
|
||||
.build();
|
||||
return Ensemble.builder()
|
||||
.setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null)
|
||||
.setClassificationLabels(includeLabels ? Arrays.asList("no", "yes") : null)
|
||||
.setTargetType(TargetType.CLASSIFICATION)
|
||||
.setFeatureNames(featureNames)
|
||||
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
|
||||
|
|
Loading…
Reference in New Issue