diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 7aa95e14b57..440a46540ce 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -28,11 +28,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.text.MessageFormat; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Objects; import java.util.Optional; @@ -66,12 +64,6 @@ public class Accuracy implements EvaluationMetric { static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy"; - private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; - - private static Script buildScript(Object...args) { - return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); - } - private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new); public static Accuracy fromXContent(XContentParser parser) { @@ -112,7 +104,8 @@ public class Accuracy implements EvaluationMetric { List aggs = new ArrayList<>(); List pipelineAggs = new ArrayList<>(); if (overallAccuracy.get() == null) { - aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField))); + Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField); + aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script)); } if (result.get() == null) { Tuple, List> matrixAggs = diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScripts.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScripts.java new file mode 100644 index 00000000000..152bef11c09 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScripts.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.script.Script; + +import java.text.MessageFormat; +import java.util.Locale; + +/** + * Painless scripts used by classification metrics in this package. + */ +final class PainlessScripts { + + /** + * Pattern of the comparison script; kept as a String because MessageFormat instances are not safe to share between threads. + * It uses "String.valueOf" method in case the mapping types of the two fields are different. + */ + private static final String COMPARISON_SCRIPT_TEMPLATE = + "String.valueOf(doc[''{0}''].value).equals(String.valueOf(doc[''{1}''].value))"; + + /** + * Builds script that tests field values equality for the given actual and predicted field names. 
+ * + * @param actualField name of the actual field + * @param predictedField name of the predicted field + * @return script that tests whether the values of actualField and predictedField are equal + */ + static Script buildIsEqualScript(String actualField, String predictedField) { + return new Script(new MessageFormat(COMPARISON_SCRIPT_TEMPLATE, Locale.ROOT).format(new Object[]{ actualField, predictedField })); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index d3f0a259c16..0ffdc22ab1c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -34,12 +34,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.text.MessageFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; @@ -60,17 +58,12 @@ public class Precision implements EvaluationMetric { public static final ParseField NAME = new ParseField("precision"); - private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; private static final String AGG_NAME_PREFIX = "classification_precision_"; static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class"; static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class"; static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision"; static final String 
AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision"; - private static Script buildScript(Object...args) { - return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); - } - private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new); public static Precision fromXContent(XContentParser parser) { @@ -117,7 +110,7 @@ public class Precision implements EvaluationMetric { topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); - Script script = buildScript(actualField, predictedField); + Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField); return Tuple.tuple( Arrays.asList( AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index fa5b277daa4..24319608150 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -20,6 +20,7 @@ import org.elasticsearch.script.Script; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.BucketOrder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders; import org.elasticsearch.search.aggregations.bucket.terms.Terms; @@ -30,12 +31,10 @@ import 
org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.text.MessageFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Objects; import java.util.Optional; @@ -55,16 +54,11 @@ public class Recall implements EvaluationMetric { public static final ParseField NAME = new ParseField("recall"); - private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; private static final String AGG_NAME_PREFIX = "classification_recall_"; static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class"; static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall"; static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall"; - private static Script buildScript(Object...args) { - return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); - } - private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new); public static Recall fromXContent(XContentParser parser) { @@ -99,11 +93,12 @@ public class Recall implements EvaluationMetric { if (result.get() != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } - Script script = buildScript(actualField, predictedField); + Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField); return Tuple.tuple( Arrays.asList( AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME) .field(actualField) + .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) .size(MAX_CLASSES_CARDINALITY) .subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))), Arrays.asList( diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScriptsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScriptsTests.java new file mode 100644 index 00000000000..5a397bae366 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PainlessScriptsTests.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.script.Script; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class PainlessScriptsTests extends ESTestCase { + + public void testBuildIsEqualScript() { + Script script = PainlessScripts.buildIsEqualScript("act", "pred"); + assertThat(script.getIdOrCode(), equalTo("String.valueOf(doc['act'].value).equals(String.valueOf(doc['pred'].value))")); + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 3985d9c1810..70f8e7ca8ae 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -38,11 +38,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; - 
private static final String ANIMAL_NAME_FIELD = "animal_name"; + private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword"; private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction"; - private static final String NO_LEGS_FIELD = "no_legs"; + private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword"; + private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer"; private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction"; - private static final String IS_PREDATOR_FIELD = "predator"; + private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword"; + private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean"; private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction"; @Before @@ -62,7 +64,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT public void testEvaluate_DefaultMetrics() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat( @@ -75,7 +77,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification( - ANIMAL_NAME_FIELD, + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); @@ -92,7 +94,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT public void testEvaluate_Accuracy_KeywordField() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new 
Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -111,10 +114,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); } - public void testEvaluate_Accuracy_IntegerField() { + private void evaluateAccuracy_IntegerField(String actualField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = - evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -133,10 +135,18 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); } - public void testEvaluate_Accuracy_BooleanField() { + public void testEvaluate_Accuracy_IntegerField() { + evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD); + } + + public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() { + evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD); + } + + private void evaluateAccuracy_BooleanField(String actualField) { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); + ANIMALS_DATA_INDEX, new Classification(actualField, 
IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -152,10 +162,19 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); } - public void testEvaluate_Precision() { + public void testEvaluate_Accuracy_BooleanField() { + evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + } + + public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() { + evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + } + + public void testEvaluate_Precision_KeywordField() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))); + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -174,6 +193,63 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75)); } + private void evaluatePrecision_IntegerField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Precision()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); + 
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); + assertThat( + precisionResult.getClasses(), + equalTo( + Arrays.asList( + new Precision.PerClassResult("1", 0.2), + new Precision.PerClassResult("2", 0.2), + new Precision.PerClassResult("3", 0.2), + new Precision.PerClassResult("4", 0.2), + new Precision.PerClassResult("5", 0.2)))); + assertThat(precisionResult.getAvgPrecision(), equalTo(0.2)); + } + + public void testEvaluate_Precision_IntegerField() { + evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD); + } + + public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() { + evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD); + } + + private void evaluatePrecision_BooleanField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Precision()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName())); + assertThat( + precisionResult.getClasses(), + equalTo( + Arrays.asList( + new Precision.PerClassResult("false", 0.5), + new Precision.PerClassResult("true", 9.0 / 13)))); + assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52)); + } + + public void testEvaluate_Precision_BooleanField() { + evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + } + + public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() { + evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + } + public void testEvaluate_Precision_CardinalityTooHigh() { indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = @@ 
-181,14 +257,15 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT ElasticsearchStatusException.class, () -> evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())))); - assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())))); + assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } - public void testEvaluate_Recall() { + public void testEvaluate_Recall_KeywordField() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))); + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -207,21 +284,80 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75)); } + private void evaluateRecall_IntegerField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Recall()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); + assertThat( + 
recallResult.getClasses(), + equalTo( + Arrays.asList( + new Recall.PerClassResult("1", 0.2), + new Recall.PerClassResult("2", 0.2), + new Recall.PerClassResult("3", 0.2), + new Recall.PerClassResult("4", 0.2), + new Recall.PerClassResult("5", 0.2)))); + assertThat(recallResult.getAvgRecall(), equalTo(0.2)); + } + + public void testEvaluate_Recall_IntegerField() { + evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD); + } + + public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() { + evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD); + } + + private void evaluateRecall_BooleanField(String actualField) { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Recall()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName())); + assertThat( + recallResult.getClasses(), + equalTo( + Arrays.asList( + new Recall.PerClassResult("true", 0.6), + new Recall.PerClassResult("false", 0.6)))); + assertThat(recallResult.getAvgRecall(), equalTo(0.6)); + } + + public void testEvaluate_Recall_BooleanField() { + evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD); + } + + public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() { + evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD); + } + public void testEvaluate_Recall_CardinalityTooHigh() { indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> evaluateDataFrame( - ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())))); - 
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); + ANIMALS_DATA_INDEX, + new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())))); + assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high")); } private void evaluateWithMulticlassConfusionMatrix() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); + new Classification( + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -301,7 +437,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null)))); + new Classification( + ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); @@ -338,11 +475,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT private static void createAnimalsIndex(String indexName) { client().admin().indices().prepareCreate(indexName) .addMapping("_doc", - ANIMAL_NAME_FIELD, "type=keyword", + ANIMAL_NAME_KEYWORD_FIELD, "type=keyword", ANIMAL_NAME_PREDICTION_FIELD, "type=keyword", - NO_LEGS_FIELD, "type=integer", + NO_LEGS_KEYWORD_FIELD, "type=keyword", + 
NO_LEGS_INTEGER_FIELD, "type=integer", NO_LEGS_PREDICTION_FIELD, "type=integer", - IS_PREDATOR_FIELD, "type=boolean", + IS_PREDATOR_KEYWORD_FIELD, "type=keyword", + IS_PREDATOR_BOOLEAN_FIELD, "type=boolean", IS_PREDATOR_PREDICTION_FIELD, "type=boolean") .get(); } @@ -357,11 +496,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT bulkRequestBuilder.add( new IndexRequest(indexName) .source( - ANIMAL_NAME_FIELD, animalNames.get(i), + ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i), ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()), - NO_LEGS_FIELD, i + 1, + NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1), + NO_LEGS_INTEGER_FIELD, i + 1, NO_LEGS_PREDICTION_FIELD, j + 1, - IS_PREDATOR_FIELD, i % 2 == 0, + IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0), + IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0, IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0)); } } @@ -377,7 +518,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < distinctAnimalCount; i++) { bulkRequestBuilder.add( - new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); + new IndexRequest(indexName) + .source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5))); } BulkResponse bulkResponse = bulkRequestBuilder.get(); if (bulkResponse.hasFailures()) {