[7.x] Make classification evaluation metrics work when there is field mapping type mismatch (#53458) (#53601)

This commit is contained in:
Przemysław Witek 2020-03-16 15:38:56 +01:00 committed by GitHub
parent e6680be0b1
commit 376b2ae735
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 230 additions and 53 deletions

View File

@ -28,11 +28,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
@ -66,12 +64,6 @@ public class Accuracy implements EvaluationMetric {
static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
}
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
public static Accuracy fromXContent(XContentParser parser) {
@ -112,7 +104,8 @@ public class Accuracy implements EvaluationMetric {
List<AggregationBuilder> aggs = new ArrayList<>();
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
if (overallAccuracy.get() == null) {
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script));
}
if (result.get() == null) {
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =

View File

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.script.Script;
import java.text.MessageFormat;
import java.util.Locale;
/**
 * Painless scripts used by classification metrics in this package.
 */
final class PainlessScripts {

    /**
     * Pattern for the comparison script, in {@link MessageFormat} syntax
     * ({@code {0}} is the actual field name, {@code {1}} is the predicted field name).
     * It uses "String.valueOf" method in case the mapping types of the two fields are different.
     *
     * The pattern is kept as a plain string instead of a shared {@link MessageFormat} instance because
     * {@link MessageFormat} is documented as not synchronized; a shared static formatter would be unsafe
     * if scripts were ever built from more than one thread at a time.
     */
    private static final String COMPARISON_SCRIPT_PATTERN =
        "String.valueOf(doc[''{0}''].value).equals(String.valueOf(doc[''{1}''].value))";

    /** Utility class holding only static members; not meant to be instantiated. */
    private PainlessScripts() {}

    /**
     * Builds script that tests field values equality for the given actual and predicted field names.
     *
     * @param actualField name of the actual field
     * @param predictedField name of the predicted field
     * @return script that tests whether the values of actualField and predictedField are equal
     */
    static Script buildIsEqualScript(String actualField, String predictedField) {
        // A fresh formatter per call keeps this method free of shared mutable state.
        MessageFormat template = new MessageFormat(COMPARISON_SCRIPT_PATTERN, Locale.ROOT);
        return new Script(template.format(new Object[]{ actualField, predictedField }));
    }
}

View File

@ -34,12 +34,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
@ -60,17 +58,12 @@ public class Precision implements EvaluationMetric {
public static final ParseField NAME = new ParseField("precision");
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String AGG_NAME_PREFIX = "classification_precision_";
static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";
private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
}
private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);
public static Precision fromXContent(XContentParser parser) {
@ -117,7 +110,7 @@ public class Precision implements EvaluationMetric {
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
Script script = buildScript(actualField, predictedField);
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)

View File

@ -20,6 +20,7 @@ import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
@ -30,12 +31,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
@ -55,16 +54,11 @@ public class Recall implements EvaluationMetric {
public static final ParseField NAME = new ParseField("recall");
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String AGG_NAME_PREFIX = "classification_recall_";
static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";
private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
}
private static final ObjectParser<Recall, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new);
public static Recall fromXContent(XContentParser parser) {
@ -99,11 +93,12 @@ public class Recall implements EvaluationMetric {
if (result.get() != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
Script script = buildScript(actualField, predictedField);
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(MAX_CLASSES_CARDINALITY)
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
Arrays.asList(

View File

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.script.Script;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.equalTo;
/**
 * Unit tests for {@link PainlessScripts}.
 */
public class PainlessScriptsTests extends ESTestCase {

    /**
     * Verifies that both field names are substituted into the generated comparison script.
     */
    public void testBuildIsEqualScript() {
        String expectedCode = "String.valueOf(doc['act'].value).equals(String.valueOf(doc['pred'].value))";
        Script isEqualScript = PainlessScripts.buildIsEqualScript("act", "pred");
        assertThat(isEqualScript.getIdOrCode(), equalTo(expectedCode));
    }
}

View File

@ -38,11 +38,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
private static final String ANIMAL_NAME_FIELD = "animal_name";
private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
private static final String NO_LEGS_FIELD = "no_legs";
private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
private static final String IS_PREDATOR_FIELD = "predator";
private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
@Before
@ -62,7 +64,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
public void testEvaluate_DefaultMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(
@ -75,7 +77,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(
ANIMAL_NAME_FIELD,
ANIMAL_NAME_KEYWORD_FIELD,
ANIMAL_NAME_PREDICTION_FIELD,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
@ -92,7 +94,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
public void testEvaluate_Accuracy_KeywordField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -111,10 +114,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
}
public void testEvaluate_Accuracy_IntegerField() {
private void evaluateAccuracy_IntegerField(String actualField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(NO_LEGS_FIELD, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -133,10 +135,18 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
}
public void testEvaluate_Accuracy_BooleanField() {
/**
 * Runs the integer-field accuracy check with {@code NO_LEGS_INTEGER_FIELD} as the actual field.
 */
public void testEvaluate_Accuracy_IntegerField() {
evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD);
}
/**
 * Runs the integer-field accuracy check with {@code NO_LEGS_KEYWORD_FIELD} as the actual field.
 * That field is mapped as keyword in {@code createAnimalsIndex} while the prediction field is mapped as integer,
 * so the actual and predicted mapping types differ.
 */
public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() {
evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD);
}
private void evaluateAccuracy_BooleanField(String actualField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(IS_PREDATOR_FIELD, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -152,10 +162,19 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
}
public void testEvaluate_Precision() {
/**
 * Runs the boolean-field accuracy check with {@code IS_PREDATOR_BOOLEAN_FIELD} as the actual field.
 */
public void testEvaluate_Accuracy_BooleanField() {
evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
}
/**
 * Runs the boolean-field accuracy check with {@code IS_PREDATOR_KEYWORD_FIELD} as the actual field.
 * That field is mapped as keyword in {@code createAnimalsIndex} while the prediction field is mapped as boolean,
 * so the actual and predicted mapping types differ.
 */
public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() {
evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
}
public void testEvaluate_Precision_KeywordField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -174,6 +193,63 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75));
}
/**
 * Evaluates the {@link Precision} metric on the animals index, using the given actual field and
 * {@code NO_LEGS_PREDICTION_FIELD} as the predicted field, and verifies the per-class and average precision.
 *
 * @param actualField name of the field holding the actual class
 */
private void evaluatePrecision_IntegerField(String actualField) {
    Classification evaluation = new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Precision()));
    EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);

    // The response must describe a classification evaluation carrying exactly one metric.
    assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
    assertThat(response.getMetrics(), hasSize(1));

    Precision.Result result = (Precision.Result) response.getMetrics().get(0);
    assertThat(result.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
    assertThat(
        result.getClasses(),
        equalTo(
            Arrays.asList(
                new Precision.PerClassResult("1", 0.2),
                new Precision.PerClassResult("2", 0.2),
                new Precision.PerClassResult("3", 0.2),
                new Precision.PerClassResult("4", 0.2),
                new Precision.PerClassResult("5", 0.2))));
    assertThat(result.getAvgPrecision(), equalTo(0.2));
}
/**
 * Runs the integer-field precision check with {@code NO_LEGS_INTEGER_FIELD} as the actual field.
 */
public void testEvaluate_Precision_IntegerField() {
evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD);
}
/**
 * Runs the integer-field precision check with {@code NO_LEGS_KEYWORD_FIELD} as the actual field.
 * That field is mapped as keyword in {@code createAnimalsIndex} while the prediction field is mapped as integer,
 * so the actual and predicted mapping types differ.
 */
public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() {
evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD);
}
/**
 * Evaluates the {@link Precision} metric on the animals index, using the given actual field and
 * {@code IS_PREDATOR_PREDICTION_FIELD} as the predicted field, and verifies the per-class and average precision.
 *
 * @param actualField name of the field holding the actual class
 */
private void evaluatePrecision_BooleanField(String actualField) {
    Classification evaluation = new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Precision()));
    EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);

    // The response must describe a classification evaluation carrying exactly one metric.
    assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
    assertThat(response.getMetrics(), hasSize(1));

    Precision.Result result = (Precision.Result) response.getMetrics().get(0);
    assertThat(result.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
    assertThat(
        result.getClasses(),
        equalTo(
            Arrays.asList(
                new Precision.PerClassResult("false", 0.5),
                new Precision.PerClassResult("true", 9.0 / 13))));
    assertThat(result.getAvgPrecision(), equalTo(31.0 / 52));
}
/**
 * Runs the boolean-field precision check with {@code IS_PREDATOR_BOOLEAN_FIELD} as the actual field.
 */
public void testEvaluate_Precision_BooleanField() {
evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
}
/**
 * Runs the boolean-field precision check with {@code IS_PREDATOR_KEYWORD_FIELD} as the actual field.
 * That field is mapped as keyword in {@code createAnimalsIndex} while the prediction field is mapped as boolean,
 * so the actual and predicted mapping types differ.
 */
public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() {
evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
}
public void testEvaluate_Precision_CardinalityTooHigh() {
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
ElasticsearchStatusException e =
@ -181,14 +257,15 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
}
public void testEvaluate_Recall() {
public void testEvaluate_Recall_KeywordField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -207,21 +284,80 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75));
}
/**
 * Evaluates the {@link Recall} metric on the animals index, using the given actual field and
 * {@code NO_LEGS_INTEGER_FIELD} as the predicted field, and verifies the per-class and average recall.
 *
 * NOTE(review): the predicted field passed here is {@code NO_LEGS_INTEGER_FIELD}, whereas the integer-field
 * accuracy and precision helpers pass {@code NO_LEGS_PREDICTION_FIELD}. Confirm this is intentional; if the
 * predicted field is changed, the expected recall values below must be recomputed from the indexed data.
 *
 * @param actualField name of the field holding the actual class
 */
private void evaluateRecall_IntegerField(String actualField) {
    Classification evaluation = new Classification(actualField, NO_LEGS_INTEGER_FIELD, Arrays.asList(new Recall()));
    EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);

    // The response must describe a classification evaluation carrying exactly one metric.
    assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
    assertThat(response.getMetrics(), hasSize(1));

    Recall.Result result = (Recall.Result) response.getMetrics().get(0);
    assertThat(result.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
    assertThat(
        result.getClasses(),
        equalTo(
            Arrays.asList(
                new Recall.PerClassResult("1", 1.0),
                new Recall.PerClassResult("2", 1.0),
                new Recall.PerClassResult("3", 1.0),
                new Recall.PerClassResult("4", 1.0),
                new Recall.PerClassResult("5", 1.0))));
    assertThat(result.getAvgRecall(), equalTo(1.0));
}
/**
 * Runs the integer-field recall check with {@code NO_LEGS_INTEGER_FIELD} as the actual field.
 */
public void testEvaluate_Recall_IntegerField() {
evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD);
}
/**
 * Runs the integer-field recall check with {@code NO_LEGS_KEYWORD_FIELD} as the actual field.
 * That field is mapped as keyword in {@code createAnimalsIndex}, while the predicted field used by the helper
 * ({@code NO_LEGS_INTEGER_FIELD}) is mapped as integer, so the actual and predicted mapping types differ.
 */
public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() {
evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD);
}
/**
 * Evaluates the {@link Recall} metric on the animals index, using the given actual field and
 * {@code IS_PREDATOR_PREDICTION_FIELD} as the predicted field, and verifies the per-class and average recall.
 *
 * @param actualField name of the field holding the actual class
 */
private void evaluateRecall_BooleanField(String actualField) {
    Classification evaluation = new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Recall()));
    EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);

    // The response must describe a classification evaluation carrying exactly one metric.
    assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
    assertThat(response.getMetrics(), hasSize(1));

    Recall.Result result = (Recall.Result) response.getMetrics().get(0);
    assertThat(result.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
    assertThat(
        result.getClasses(),
        equalTo(
            Arrays.asList(
                new Recall.PerClassResult("true", 0.6),
                new Recall.PerClassResult("false", 0.6))));
    assertThat(result.getAvgRecall(), equalTo(0.6));
}
/**
 * Runs the boolean-field recall check with {@code IS_PREDATOR_BOOLEAN_FIELD} as the actual field.
 */
public void testEvaluate_Recall_BooleanField() {
evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
}
/**
 * Runs the boolean-field recall check with {@code IS_PREDATOR_KEYWORD_FIELD} as the actual field.
 * That field is mapped as keyword in {@code createAnimalsIndex} while the prediction field is mapped as boolean,
 * so the actual and predicted mapping types differ.
 */
public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() {
evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
}
public void testEvaluate_Recall_CardinalityTooHigh() {
indexDistinctAnimals(ANIMALS_DATA_INDEX, 1001);
ElasticsearchStatusException e =
expectThrows(
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
}
private void evaluateWithMulticlassConfusionMatrix() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
new Classification(
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -301,7 +437,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
new Classification(
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -338,11 +475,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
private static void createAnimalsIndex(String indexName) {
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc",
ANIMAL_NAME_FIELD, "type=keyword",
ANIMAL_NAME_KEYWORD_FIELD, "type=keyword",
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
NO_LEGS_FIELD, "type=integer",
NO_LEGS_KEYWORD_FIELD, "type=keyword",
NO_LEGS_INTEGER_FIELD, "type=integer",
NO_LEGS_PREDICTION_FIELD, "type=integer",
IS_PREDATOR_FIELD, "type=boolean",
IS_PREDATOR_KEYWORD_FIELD, "type=keyword",
IS_PREDATOR_BOOLEAN_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
.get();
}
@ -357,11 +496,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
bulkRequestBuilder.add(
new IndexRequest(indexName)
.source(
ANIMAL_NAME_FIELD, animalNames.get(i),
ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i),
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
NO_LEGS_FIELD, i + 1,
NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1),
NO_LEGS_INTEGER_FIELD, i + 1,
NO_LEGS_PREDICTION_FIELD, j + 1,
IS_PREDATOR_FIELD, i % 2 == 0,
IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0),
IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0,
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
}
}
@ -377,7 +518,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < distinctAnimalCount; i++) {
bulkRequestBuilder.add(
new IndexRequest(indexName).source(ANIMAL_NAME_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
new IndexRequest(indexName)
.source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {