diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index ed82a4da476..ce8e9abcc2a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -139,7 +139,6 @@ public final class InferenceHelpers { public static List transformFeatureImportanceClassification( Map featureImportance, - final int predictedValue, @Nullable List classificationLabels, @Nullable PredictionFieldType predictionFieldType) { List importances = new ArrayList<>(featureImportance.size()); @@ -148,20 +147,20 @@ public final class InferenceHelpers { // This indicates logistic regression (binary classification) // If the length > 1, we assume multi-class classification. if (v.length == 1) { - assert predictedValue == 1 || predictedValue == 0; - // If predicted value is `1`, then the other class is `0` - // If predicted value is `0`, then the other class is `1` - final int otherClass = 1 - predictedValue; - String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue); - String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass); + String zeroLabel = classificationLabels == null ? null : classificationLabels.get(0); + String oneLabel = classificationLabels == null ? null : classificationLabels.get(1); + // Feature importance is built from the values in the leaves. + // These leaves indicate the direction in which the feature pulls the value. + // The original importance indicates how the feature pushes the value towards, or pulls it away from, `1`. + // To get the importance for the `0` class, we simply negate it. 
importances.add(new ClassificationFeatureImportance(k, Arrays.asList( new ClassificationFeatureImportance.ClassImportance( - fieldType.transformPredictedValue((double)predictedValue, predictedLabel), - v[0]), + fieldType.transformPredictedValue(0.0, zeroLabel), + -v[0]), new ClassificationFeatureImportance.ClassImportance( - fieldType.transformPredictedValue((double)otherClass, otherLabel), - -v[0]) + fieldType.transformPredictedValue(1.0, oneLabel), + v[0]) ))); } else { List classImportance = new ArrayList<>(v.length); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldType.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldType.java index f0a8f08aa00..f94220b82ce 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldType.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldType.java @@ -52,22 +52,25 @@ public enum PredictionFieldType implements Writeable { case STRING: return stringRep == null ? value.toString() : stringRep; case BOOLEAN: - if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) { - throw new IllegalArgumentException( - "Cannot transform numbers other than 0.0 or 1.0 to boolean. 
Provided number [" + value + "]"); + if (isNumberQuickCheck(stringRep)) { + try { + // 1 is true, any other integer (e.g. 0) is false + return Integer.parseInt(stringRep) == 1; + } catch (NumberFormatException nfe) { + // do nothing, allow fall through to final fromDouble + } + } else if (isBoolQuickCheck(stringRep)) { // a leading t/f (case-insensitive) indicates a boolean string + return Boolean.parseBoolean(stringRep); } - return areClose(value, 1.0D); + return fromDouble(value); case NUMBER: - if (Strings.isNullOrEmpty(stringRep)) { - return value; - } // Quick check to verify that the string rep is LIKELY a number // Still handles the case where it throws and then returns the underlying value - if (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0))) { + if (isNumberQuickCheck(stringRep)) { try { return Long.parseLong(stringRep); } catch (NumberFormatException nfe) { - return value; + // do nothing, allow fall through to final return } } return value; @@ -76,7 +79,27 @@ public enum PredictionFieldType implements Writeable { } } + private static boolean fromDouble(double value) { + if ((areClose(value, 1.0D) || areClose(value, 0.0D)) == false) { + throw new IllegalArgumentException( + "Cannot transform numbers other than 0.0 or 1.0 to boolean. 
Provided number [" + value + "]"); + } + return areClose(value, 1.0D); + } + private static boolean areClose(double value1, double value2) { return Math.abs(value1 - value2) < EPS; } + + private static boolean isNumberQuickCheck(String stringRep) { + return Strings.isNullOrEmpty(stringRep) == false && (stringRep.charAt(0) == '-' || Character.isDigit(stringRep.charAt(0))); + } + + private static boolean isBoolQuickCheck(String stringRep) { + if (Strings.isNullOrEmpty(stringRep)) { + return false; + } + char c = stringRep.charAt(0); + return 't' == c || 'T' == c || 'f' == c || 'F' == c; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index 72256b35e7f..acaa201d089 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -213,7 +213,6 @@ public class EnsembleInferenceModel implements InferenceModel { classificationLabel(topClasses.v1().getValue(), classificationLabels), topClasses.v2(), transformFeatureImportanceClassification(decodedFeatureImportance, - value.getValue(), classificationLabels, classificationConfig.getPredictionFieldType()), config, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index c4649503a49..294bd583c40 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -180,7 +180,6 @@ public class TreeInferenceModel implements InferenceModel { classificationLabel(classificationValue.getValue(), classificationLabels), topClasses.v2(), InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance, - classificationValue.getValue(), classificationLabels, classificationConfig.getPredictionFieldType()), config, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldTypeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldTypeTests.java index 0901ee62e74..cb7f5f1cb18 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldTypeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PredictionFieldTypeTests.java @@ -13,17 +13,27 @@ import static org.hamcrest.Matchers.nullValue; public class PredictionFieldTypeTests extends ESTestCase { + private static final String NOT_BOOLEAN = "not_boolean"; + public void testTransformPredictedValueBoolean() { - assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : randomAlphaOfLength(10)), + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(null, randomBoolean() ? null : NOT_BOOLEAN), is(nullValue())); - assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : randomAlphaOfLength(10)), + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, randomBoolean() ? null : NOT_BOOLEAN), is(true)); - assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? null : randomAlphaOfLength(10)), + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, randomBoolean() ? 
null : NOT_BOOLEAN), is(false)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "1"), is(true)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "0"), is(false)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "TruE"), is(true)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(0.0, "fAlsE"), is(false)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "0"), is(false)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "1"), is(true)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "TruE"), is(true)); + assertThat(PredictionFieldType.BOOLEAN.transformPredictedValue(1.0, "fAlse"), is(false)); expectThrows(IllegalArgumentException.class, - () -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : randomAlphaOfLength(10))); + () -> PredictionFieldType.BOOLEAN.transformPredictedValue(0.1, randomBoolean() ? null : NOT_BOOLEAN)); expectThrows(IllegalArgumentException.class, - () -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? null : randomAlphaOfLength(10))); + () -> PredictionFieldType.BOOLEAN.transformPredictedValue(1.1, randomBoolean() ? 
null : NOT_BOOLEAN)); } public void testTransformPredictedValueString() { diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index ad69c051b80..b2d90043cc5 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -165,7 +165,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { .stream() .map(i -> ((SingleValueInferenceResults)i).valueAsString()) .collect(Collectors.toList()), - contains("not_to_be", "to_be")); + contains("no", "yes")); // Get top classes request = new InternalInferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true); @@ -174,14 +174,14 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { ClassificationInferenceResults classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); - assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); - assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("no")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("yes")); assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1); - 
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); - assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("no")); // they should always be in order of Most probable to least assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); @@ -192,7 +192,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase { classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); assertThat(classificationInferenceResults.getTopClasses(), hasSize(1)); - assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("yes")); } public void testInferModelMultiClassModel() throws Exception { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 19f99edf77e..d57c2510d9d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -114,13 +114,13 @@ public class LocalModelTests extends ESTestCase { mock(CircuitBreaker.class)); result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); - assertThat(result.valueAsString(), equalTo("not_to_be")); + assertThat(result.valueAsString(), 
equalTo("no")); classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfigUpdate(1, null, null, null, null)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); - assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("no")); assertThat(model.getLatestStatsAndReset().getInferenceCount(), equalTo(2L)); classificationResult = (ClassificationInferenceResults)getSingleValue(model, @@ -169,11 +169,11 @@ public class LocalModelTests extends ESTestCase { IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); writeResult(result, document, "result_field", modelId); - assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("not_to_be")); + assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("no")); List list = document.getFieldValue("result_field.top_classes", List.class); assertThat(list.size(), equalTo(2)); - assertThat(((Map)list.get(0)).get("class_name"), equalTo("not_to_be")); - assertThat(((Map)list.get(1)).get("class_name"), equalTo("to_be")); + assertThat(((Map)list.get(0)).get("class_name"), equalTo("no")); + assertThat(((Map)list.get(1)).get("class_name"), equalTo("yes")); result = getInferenceResult(model, fields, new ClassificationConfigUpdate(2, null, null, null, PredictionFieldType.NUMBER)); @@ -440,7 +440,7 @@ public class LocalModelTests extends ESTestCase { .addNode(TreeNode.builder(2).setLeafValue(0.0)) .build(); return Ensemble.builder() - .setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null) + .setClassificationLabels(includeLabels ? Arrays.asList("no", "yes") : null) .setTargetType(TargetType.CLASSIFICATION) .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3))