diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 1dc3614723d..13c08098776 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -143,7 +143,7 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { if (result.get() == null) { // These are steps 2, 3, 4 etc. KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true))) .toArray(KeyedFilter[]::new); // Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that // too_many_buckets_exception exception is not thrown. @@ -154,7 +154,7 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { topActualClassNames.get().stream() .skip(actualClasses.size()) .limit(actualClassesPerBatch) - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(actualField, className).lenient(true))) .toArray(KeyedFilter[]::new); if (keyedFiltersActual.length > 0) { return Tuple.tuple( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index 0ffdc22ab1c..b90bfd8cce6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -108,7 +108,7 @@ public class Precision implements EvaluationMetric { if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersPredicted = topActualClassNames.get().stream() - .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true))) .toArray(KeyedFilter[]::new); Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField); return Tuple.tuple( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 70f8e7ca8ae..6e135c9995d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -29,23 +29,25 @@ import java.util.List; import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notANumber; public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword"; - private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction"; + private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction"; private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword"; private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer"; - private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction"; + private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction"; private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword"; private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean"; - private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction"; + private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction"; @Before public void setup() { @@ -64,7 +66,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT public void testEvaluate_DefaultMetrics() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -78,7 +81,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT ANIMALS_DATA_INDEX, new Classification( ANIMAL_NAME_KEYWORD_FIELD, - ANIMAL_NAME_PREDICTION_FIELD, + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); @@ -91,163 +94,257 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT Recall.NAME.getPreferredName())); } - public void testEvaluate_Accuracy_KeywordField() { + public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() { + String indexName = "some-index"; + String actualField = "fieldA"; + String predictedField = "fieldB"; + client().admin().indices().prepareCreate(indexName) + .addMapping("_doc", + actualField, "type=keyword", + predictedField, "type=keyword") + .get(); + client().prepareIndex(indexName, "_doc") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource( + actualField, "crocodile", + predictedField, "cRoCoDiLe") + .get(); + EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); - - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + indexName, + new Classification( + actualField, + predictedField, + Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat(accuracyResult.getClasses(), contains(new Accuracy.PerClassResult("crocodile", 0.0))); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.0)); + + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1); assertThat( - accuracyResult.getClasses(), - equalTo( - Arrays.asList( - new Accuracy.PerClassResult("ant", 47.0 / 75), - new Accuracy.PerClassResult("cat", 47.0 / 75), - new Accuracy.PerClassResult("dog", 47.0 / 75), - new Accuracy.PerClassResult("fox", 47.0 / 75), - new Accuracy.PerClassResult("mouse", 47.0 / 75)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); + confusionMatrixResult.getConfusionMatrix(), + equalTo(Arrays.asList( + new MulticlassConfusionMatrix.ActualClass( + "crocodile", 1, Arrays.asList(new MulticlassConfusionMatrix.PredictedClass("crocodile", 0L)), 1)))); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2); + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3); + assertThat(recallResult.getClasses(), contains(new Recall.PerClassResult("crocodile", 0.0))); + assertThat(recallResult.getAvgRecall(), equalTo(0.0)); } - private void evaluateAccuracy_IntegerField(String actualField) { + private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - assertThat( - accuracyResult.getClasses(), - equalTo( - Arrays.asList( - new Accuracy.PerClassResult("1", 57.0 / 75), - new Accuracy.PerClassResult("2", 54.0 / 75), - new Accuracy.PerClassResult("3", 51.0 / 75), - new Accuracy.PerClassResult("4", 48.0 / 75), - new Accuracy.PerClassResult("5", 45.0 / 75)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); + return accuracyResult; + } + + public void testEvaluate_Accuracy_KeywordField() { + List expectedPerClassResults = + Arrays.asList( + new Accuracy.PerClassResult("ant", 47.0 / 75), + new Accuracy.PerClassResult("cat", 47.0 / 75), + new Accuracy.PerClassResult("dog", 47.0 / 75), + new Accuracy.PerClassResult("fox", 47.0 / 75), + new Accuracy.PerClassResult("mouse", 47.0 / 75)); + double expectedOverallAccuracy = 5.0 / 75; + + Accuracy.Result accuracyResult = evaluateAccuracy(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); } public void testEvaluate_Accuracy_IntegerField() { - evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD); - } + List expectedPerClassResults = + Arrays.asList( + new Accuracy.PerClassResult("1", 57.0 / 75), + new Accuracy.PerClassResult("2", 54.0 / 75), + new Accuracy.PerClassResult("3", 51.0 / 75), + new Accuracy.PerClassResult("4", 48.0 / 75), + new Accuracy.PerClassResult("5", 45.0 / 75)); + double expectedOverallAccuracy = 15.0 / 75; - public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() { - evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD); - } + Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); - private void evaluateAccuracy_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); - Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - assertThat( - accuracyResult.getClasses(), - equalTo( - Arrays.asList( - new Accuracy.PerClassResult("false", 18.0 / 30), - new Accuracy.PerClassResult("true", 27.0 / 45)))); - assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); } public void testEvaluate_Accuracy_BooleanField() { - evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + List expectedPerClassResults = + Arrays.asList( + new Accuracy.PerClassResult("false", 18.0 / 30), + new Accuracy.PerClassResult("true", 27.0 / 45)); + double expectedOverallAccuracy = 45.0 / 75; + + Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric) + accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + + // Actual and predicted fields are of different types but the values are matched correctly + accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); } - public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() { - evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + public void testEvaluate_Accuracy_FieldTypeMismatch() { + { + // When actual and predicted fields have different types, the sets of classes are disjoint + List expectedPerClassResults = + Arrays.asList( + new Accuracy.PerClassResult("1", 0.8), + new Accuracy.PerClassResult("2", 0.8), + new Accuracy.PerClassResult("3", 0.8), + new Accuracy.PerClassResult("4", 0.8), + new Accuracy.PerClassResult("5", 0.8)); + double expectedOverallAccuracy = 0.0; + + Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } + { + // When actual and predicted fields have different types, the sets of classes are disjoint + List expectedPerClassResults = + Arrays.asList( + new Accuracy.PerClassResult("false", 0.6), + new Accuracy.PerClassResult("true", 0.4)); + double expectedOverallAccuracy = 0.0; + + Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy)); + } + } + + private Precision.Result evaluatePrecision(String actualField, String predictedField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Precision()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); + return precisionResult; } public void testEvaluate_Precision_KeywordField() { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))); + List expectedPerClassResults = + Arrays.asList( + new Precision.PerClassResult("ant", 1.0 / 15), + new Precision.PerClassResult("cat", 1.0 / 15), + new Precision.PerClassResult("dog", 1.0 / 15), + new Precision.PerClassResult("fox", 1.0 / 15), + new Precision.PerClassResult("mouse", 1.0 / 15)); + double expectedAvgPrecision = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Precision.Result precisionResult = evaluatePrecision(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - Arrays.asList( - new Precision.PerClassResult("ant", 1.0 / 15), - new Precision.PerClassResult("cat", 1.0 / 15), - new Precision.PerClassResult("dog", 1.0 / 15), - new Precision.PerClassResult("fox", 1.0 / 15), - new Precision.PerClassResult("mouse", 1.0 / 15)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75)); - } - - private void evaluatePrecision_IntegerField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Precision()))); - - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); - - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - Arrays.asList( - new Precision.PerClassResult("1", 0.2), - new Precision.PerClassResult("2", 0.2), - new Precision.PerClassResult("3", 0.2), - new Precision.PerClassResult("4", 0.2), - new Precision.PerClassResult("5", 0.2)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(0.2)); + evaluatePrecision(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); } public void testEvaluate_Precision_IntegerField() { - evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD); - } + List expectedPerClassResults = + Arrays.asList( + new Precision.PerClassResult("1", 0.2), + new Precision.PerClassResult("2", 0.2), + new Precision.PerClassResult("3", 0.2), + new Precision.PerClassResult("4", 0.2), + new Precision.PerClassResult("5", 0.2)); + double expectedAvgPrecision = 0.2; - public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() { - evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD); - } + Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - private void evaluatePrecision_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Precision()))); + // Actual and predicted fields are of different types but the values are matched correctly + precisionResult = evaluatePrecision(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); - Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); - assertThat( - precisionResult.getClasses(), - equalTo( - Arrays.asList( - new Precision.PerClassResult("false", 0.5), - new Precision.PerClassResult("true", 9.0 / 13)))); - assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52)); + evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); } public void testEvaluate_Precision_BooleanField() { - evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + List expectedPerClassResults = + Arrays.asList( + new Precision.PerClassResult("false", 0.5), + new Precision.PerClassResult("true", 9.0 / 13)); + double expectedAvgPrecision = 31.0 / 52; + + Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); + + // Actual and predicted fields are of different types but the values are matched correctly + precisionResult = evaluatePrecision(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision)); + + evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + + evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); } - public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() { - evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + public void testEvaluate_Precision_FieldTypeMismatch() { + { + Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + // When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + } + { + Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + // When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here + assertThat(precisionResult.getClasses(), empty()); + assertThat(precisionResult.getAvgPrecision(), is(notANumber())); + } } public void testEvaluate_Precision_CardinalityTooHigh() { @@ -257,88 +354,112 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Precision())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - public void testEvaluate_Recall_KeywordField() { + private Recall.Result evaluateRecall(String actualField, String predictedField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - Arrays.asList( - new Recall.PerClassResult("ant", 1.0 / 15), - new Recall.PerClassResult("cat", 1.0 / 15), - new Recall.PerClassResult("dog", 1.0 / 15), - new Recall.PerClassResult("fox", 1.0 / 15), - new Recall.PerClassResult("mouse", 1.0 / 15)))); - assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75)); + return recallResult; } - private void evaluateRecall_IntegerField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_INTEGER_FIELD, Arrays.asList(new Recall()))); + public void testEvaluate_Recall_KeywordField() { + List expectedPerClassResults = + Arrays.asList( + new Recall.PerClassResult("ant", 1.0 / 15), + new Recall.PerClassResult("cat", 1.0 / 15), + new Recall.PerClassResult("dog", 1.0 / 15), + new Recall.PerClassResult("fox", 1.0 / 15), + new Recall.PerClassResult("mouse", 1.0 / 15)); + double expectedAvgRecall = 5.0 / 75; - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + Recall.Result recallResult = evaluateRecall(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - Arrays.asList( - new Recall.PerClassResult("1", 1.0), - new Recall.PerClassResult("2", 1.0), - new Recall.PerClassResult("3", 1.0), - new Recall.PerClassResult("4", 1.0), - new Recall.PerClassResult("5", 1.0)))); - assertThat(recallResult.getAvgRecall(), equalTo(1.0)); + evaluateRecall(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD); } public void testEvaluate_Recall_IntegerField() { - evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD); - } + List expectedPerClassResults = + Arrays.asList( + new Recall.PerClassResult("1", 1.0 / 15), + new Recall.PerClassResult("2", 2.0 / 15), + new Recall.PerClassResult("3", 3.0 / 15), + new Recall.PerClassResult("4", 4.0 / 15), + new Recall.PerClassResult("5", 5.0 / 15)); + double expectedAvgRecall = 3.0 / 15; - public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() { - evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD); - } + Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - private void evaluateRecall_BooleanField(String actualField) { - EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Recall()))); + // Actual and predicted fields are of different types but the values are matched correctly + recallResult = evaluateRecall(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); - assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD); - Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); - assertThat( - recallResult.getClasses(), - equalTo( - Arrays.asList( - new Recall.PerClassResult("true", 0.6), - new Recall.PerClassResult("false", 0.6)))); - assertThat(recallResult.getAvgRecall(), equalTo(0.6)); + evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD); } public void testEvaluate_Recall_BooleanField() { - evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + List expectedPerClassResults = + Arrays.asList( + new Recall.PerClassResult("true", 0.6), + new Recall.PerClassResult("false", 0.6)); + double expectedAvgRecall = 0.6; + + Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + + // Actual and predicted fields are of different types but the values are matched correctly + recallResult = evaluateRecall(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + + evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + + evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD); } - public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() { - evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + public void testEvaluate_Recall_FieldTypeMismatch() { + { + // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here + List expectedPerClassResults = + Arrays.asList( + new Recall.PerClassResult("1", 0.0), + new Recall.PerClassResult("2", 0.0), + new Recall.PerClassResult("3", 0.0), + new Recall.PerClassResult("4", 0.0), + new Recall.PerClassResult("5", 0.0)); + double expectedAvgRecall = 0.0; + + Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + } + { + // When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here + List expectedPerClassResults = + Arrays.asList( + new Recall.PerClassResult("true", 0.0), + new Recall.PerClassResult("false", 0.0)); + double expectedAvgRecall = 0.0; + + Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD); + assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults)); + assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall)); + } } public void testEvaluate_Recall_CardinalityTooHigh() { @@ -348,16 +469,16 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())))); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Recall())))); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - private void evaluateWithMulticlassConfusionMatrix() { + private void evaluateMulticlassConfusionMatrix() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -417,16 +538,16 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT } public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get(); - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get(); - evaluateWithMulticlassConfusionMatrix(); + evaluateMulticlassConfusionMatrix(); client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get(); - ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix); + ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateMulticlassConfusionMatrix); assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class))); TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause(); @@ -438,7 +559,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null)))); + ANIMAL_NAME_KEYWORD_FIELD, + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, + Arrays.asList(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -476,13 +599,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT client().admin().indices().prepareCreate(indexName) .addMapping("_doc", ANIMAL_NAME_KEYWORD_FIELD, "type=keyword", - ANIMAL_NAME_PREDICTION_FIELD, "type=keyword", + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, "type=keyword", NO_LEGS_KEYWORD_FIELD, "type=keyword", NO_LEGS_INTEGER_FIELD, "type=integer", - NO_LEGS_PREDICTION_FIELD, "type=integer", + NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer", IS_PREDATOR_KEYWORD_FIELD, "type=keyword", IS_PREDATOR_BOOLEAN_FIELD, "type=boolean", - IS_PREDATOR_PREDICTION_FIELD, "type=boolean") + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean") .get(); } @@ -497,13 +620,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT new IndexRequest(indexName) .source( ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i), - ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()), + ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()), NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1), NO_LEGS_INTEGER_FIELD, i + 1, - NO_LEGS_PREDICTION_FIELD, j + 1, + NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1, IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0), IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0, - IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0)); + IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0)); } } } @@ -519,7 +642,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT for (int i = 0; i < distinctAnimalCount; i++) { bulkRequestBuilder.add( new IndexRequest(indexName) - .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); + .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, randomAlphaOfLength(5))); } BulkResponse bulkResponse = bulkRequestBuilder.get(); if (bulkResponse.hasFailures()) {