[7.x] Make classification evaluation metrics work when there is field mapping type mismatch (#53458) (#53601)
This commit is contained in:
parent
e6680be0b1
commit
376b2ae735
|
@ -28,11 +28,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters
|
|||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
|
@ -66,12 +64,6 @@ public class Accuracy implements EvaluationMetric {
|
|||
|
||||
static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
|
||||
private static Script buildScript(Object...args) {
|
||||
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
||||
}
|
||||
|
||||
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
|
||||
|
||||
public static Accuracy fromXContent(XContentParser parser) {
|
||||
|
@ -112,7 +104,8 @@ public class Accuracy implements EvaluationMetric {
|
|||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
|
||||
if (overallAccuracy.get() == null) {
|
||||
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
|
||||
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
|
||||
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script));
|
||||
}
|
||||
if (result.get() == null) {
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.script.Script;
|
||||
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* Painless scripts used by classification metrics in this package.
|
||||
*/
|
||||
final class PainlessScripts {
|
||||
|
||||
/**
|
||||
* Template for the comparison script.
|
||||
* It uses "String.valueOf" method in case the mapping types of the two fields are different.
|
||||
*/
|
||||
private static final MessageFormat COMPARISON_SCRIPT_TEMPLATE =
|
||||
new MessageFormat("String.valueOf(doc[''{0}''].value).equals(String.valueOf(doc[''{1}''].value))", Locale.ROOT);
|
||||
|
||||
/**
|
||||
* Builds script that tests field values equality for the given actual and predicted field names.
|
||||
*
|
||||
* @param actualField name of the actual field
|
||||
* @param predictedField name of the predicted field
|
||||
* @return script that tests whether the values of actualField and predictedField are equal
|
||||
*/
|
||||
static Script buildIsEqualScript(String actualField, String predictedField) {
|
||||
return new Script(COMPARISON_SCRIPT_TEMPLATE.format(new Object[]{ actualField, predictedField }));
|
||||
}
|
||||
}
|
|
@ -34,12 +34,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters
|
|||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -60,17 +58,12 @@ public class Precision implements EvaluationMetric {
|
|||
|
||||
public static final ParseField NAME = new ParseField("precision");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
private static final String AGG_NAME_PREFIX = "classification_precision_";
|
||||
static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
|
||||
static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
|
||||
static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
|
||||
static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";
|
||||
|
||||
private static Script buildScript(Object...args) {
|
||||
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
||||
}
|
||||
|
||||
private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);
|
||||
|
||||
public static Precision fromXContent(XContentParser parser) {
|
||||
|
@ -117,7 +110,7 @@ public class Precision implements EvaluationMetric {
|
|||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
Script script = buildScript(actualField, predictedField);
|
||||
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.script.Script;
|
|||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.BucketOrder;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
|
@ -30,12 +31,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters
|
|||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
|
@ -55,16 +54,11 @@ public class Recall implements EvaluationMetric {
|
|||
|
||||
public static final ParseField NAME = new ParseField("recall");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
private static final String AGG_NAME_PREFIX = "classification_recall_";
|
||||
static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
|
||||
static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
|
||||
static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";
|
||||
|
||||
private static Script buildScript(Object...args) {
|
||||
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
||||
}
|
||||
|
||||
private static final ObjectParser<Recall, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new);
|
||||
|
||||
public static Recall fromXContent(XContentParser parser) {
|
||||
|
@ -99,11 +93,12 @@ public class Recall implements EvaluationMetric {
|
|||
if (result.get() != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
Script script = buildScript(actualField, predictedField);
|
||||
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(MAX_CLASSES_CARDINALITY)
|
||||
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
|
||||
Arrays.asList(
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class PainlessScriptsTests extends ESTestCase {
|
||||
|
||||
public void testBuildIsEqualScript() {
|
||||
Script script = PainlessScripts.buildIsEqualScript("act", "pred");
|
||||
assertThat(script.getIdOrCode(), equalTo("String.valueOf(doc['act'].value).equals(String.valueOf(doc['pred'].value))"));
|
||||
}
|
||||
}
|
|
@ -38,11 +38,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
|
||||
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
|
||||
|
||||
private static final String ANIMAL_NAME_FIELD = "animal_name";
|
||||
private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
|
||||
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
|
||||
private static final String NO_LEGS_FIELD = "no_legs";
|
||||
private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
|
||||
private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
|
||||
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
|
||||
private static final String IS_PREDATOR_FIELD = "predator";
|
||||
private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
|
||||
private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
|
||||
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
|
||||
|
||||
@Before
|
||||
|
@ -62,7 +64,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
|
||||
public void testEvaluate_DefaultMetrics() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
|
@ -75,7 +77,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(
|
||||
ANIMAL_NAME_FIELD,
|
||||
ANIMAL_NAME_KEYWORD_FIELD,
|
||||
ANIMAL_NAME_PREDICTION_FIELD,
|
||||
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
|
||||
|
||||
|
@ -92,7 +94,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
public void testEvaluate_Accuracy_KeywordField() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -111,10 +114,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_IntegerField() {
|
||||
private void evaluateAccuracy_IntegerField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -133,10 +135,18 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_BooleanField() {
|
||||
public void testEvaluate_Accuracy_IntegerField() {
|
||||
evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() {
|
||||
evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
private void evaluateAccuracy_BooleanField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -152,10 +162,19 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision() {
|
||||
public void testEvaluate_Accuracy_BooleanField() {
|
||||
evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() {
|
||||
evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_KeywordField() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -174,6 +193,63 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
private void evaluatePrecision_IntegerField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("1", 0.2),
|
||||
new Precision.PerClassResult("2", 0.2),
|
||||
new Precision.PerClassResult("3", 0.2),
|
||||
new Precision.PerClassResult("4", 0.2),
|
||||
new Precision.PerClassResult("5", 0.2))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(0.2));
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_IntegerField() {
|
||||
evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() {
|
||||
evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
private void evaluatePrecision_BooleanField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("false", 0.5),
|
||||
new Precision.PerClassResult("true", 9.0 / 13))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52));
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_BooleanField() {
|
||||
evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() {
|
||||
evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_CardinalityTooHigh() {
|
||||
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
|
||||
ElasticsearchStatusException e =
|
||||
|
@ -181,14 +257,15 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall() {
|
||||
public void testEvaluate_Recall_KeywordField() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -207,21 +284,80 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
private void evaluateRecall_IntegerField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_INTEGER_FIELD, Arrays.asList(new Recall())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("1", 1.0),
|
||||
new Recall.PerClassResult("2", 1.0),
|
||||
new Recall.PerClassResult("3", 1.0),
|
||||
new Recall.PerClassResult("4", 1.0),
|
||||
new Recall.PerClassResult("5", 1.0))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(1.0));
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_IntegerField() {
|
||||
evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() {
|
||||
evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
private void evaluateRecall_BooleanField(String actualField) {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Recall())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("true", 0.6),
|
||||
new Recall.PerClassResult("false", 0.6))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(0.6));
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_BooleanField() {
|
||||
evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() {
|
||||
evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_CardinalityTooHigh() {
|
||||
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
|
||||
}
|
||||
|
||||
private void evaluateWithMulticlassConfusionMatrix() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
new Classification(
|
||||
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -301,7 +437,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
|
||||
new Classification(
|
||||
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
@ -338,11 +475,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
private static void createAnimalsIndex(String indexName) {
|
||||
client().admin().indices().prepareCreate(indexName)
|
||||
.addMapping("_doc",
|
||||
ANIMAL_NAME_FIELD, "type=keyword",
|
||||
ANIMAL_NAME_KEYWORD_FIELD, "type=keyword",
|
||||
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
|
||||
NO_LEGS_FIELD, "type=integer",
|
||||
NO_LEGS_KEYWORD_FIELD, "type=keyword",
|
||||
NO_LEGS_INTEGER_FIELD, "type=integer",
|
||||
NO_LEGS_PREDICTION_FIELD, "type=integer",
|
||||
IS_PREDATOR_FIELD, "type=boolean",
|
||||
IS_PREDATOR_KEYWORD_FIELD, "type=keyword",
|
||||
IS_PREDATOR_BOOLEAN_FIELD, "type=boolean",
|
||||
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
|
||||
.get();
|
||||
}
|
||||
|
@ -357,11 +496,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
bulkRequestBuilder.add(
|
||||
new IndexRequest(indexName)
|
||||
.source(
|
||||
ANIMAL_NAME_FIELD, animalNames.get(i),
|
||||
ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i),
|
||||
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
|
||||
NO_LEGS_FIELD, i + 1,
|
||||
NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1),
|
||||
NO_LEGS_INTEGER_FIELD, i + 1,
|
||||
NO_LEGS_PREDICTION_FIELD, j + 1,
|
||||
IS_PREDATOR_FIELD, i % 2 == 0,
|
||||
IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0),
|
||||
IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0,
|
||||
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
|
||||
}
|
||||
}
|
||||
|
@ -377,7 +518,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < distinctAnimalCount; i++) {
|
||||
bulkRequestBuilder.add(
|
||||
new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
|
||||
new IndexRequest(indexName)
|
||||
.source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
|
|
Loading…
Reference in New Issue