[7.x] Do not fail Evaluate API when the actual and predicted fields' types differ (#54255) (#54319)

This commit is contained in:
Przemysław Witek 2020-03-27 10:05:19 +01:00 committed by GitHub
parent c547fabb2b
commit d40afc7871
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 316 additions and 193 deletions

View File

@ -143,7 +143,7 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
if (result.get() == null) { // These are steps 2, 3, 4 etc.
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true)))
.toArray(KeyedFilter[]::new);
// Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that
// too_many_buckets_exception exception is not thrown.
@ -154,7 +154,7 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
topActualClassNames.get().stream()
.skip(actualClasses.size())
.limit(actualClassesPerBatch)
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(actualField, className).lenient(true)))
.toArray(KeyedFilter[]::new);
if (keyedFiltersActual.length > 0) {
return Tuple.tuple(

View File

@ -108,7 +108,7 @@ public class Precision implements EvaluationMetric {
if (result.get() == null) { // This is step 2
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.map(className -> new KeyedFilter(className, QueryBuilders.matchQuery(predictedField, className).lenient(true)))
.toArray(KeyedFilter[]::new);
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
return Tuple.tuple(

View File

@ -29,23 +29,25 @@ import java.util.List;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notANumber;
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
private static final String ANIMAL_NAME_PREDICTION_FIELD = "animal_name_prediction";
private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction";
private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
private static final String NO_LEGS_PREDICTION_FIELD = "no_legs_prediction";
private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction";
private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
private static final String IS_PREDATOR_PREDICTION_FIELD = "predator_prediction";
private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction";
@Before
public void setup() {
@ -64,7 +66,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
public void testEvaluate_DefaultMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(
@ -78,7 +81,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
ANIMALS_DATA_INDEX,
new Classification(
ANIMAL_NAME_KEYWORD_FIELD,
ANIMAL_NAME_PREDICTION_FIELD,
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
@ -91,163 +94,257 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
Recall.NAME.getPreferredName()));
}
public void testEvaluate_Accuracy_KeywordField() {
public void testEvaluate_AllMetrics_KeywordField_CaseSensitivity() {
String indexName = "some-index";
String actualField = "fieldA";
String predictedField = "fieldB";
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc",
actualField, "type=keyword",
predictedField, "type=keyword")
.get();
client().prepareIndex(indexName, "_doc")
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setSource(
actualField, "crocodile",
predictedField, "cRoCoDiLe")
.get();
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
indexName,
new Classification(
actualField,
predictedField,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(accuracyResult.getClasses(), contains(new Accuracy.PerClassResult("crocodile", 0.0)));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.0));
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1);
assertThat(
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
new Accuracy.PerClassResult("ant", 47.0 / 75),
new Accuracy.PerClassResult("cat", 47.0 / 75),
new Accuracy.PerClassResult("dog", 47.0 / 75),
new Accuracy.PerClassResult("fox", 47.0 / 75),
new Accuracy.PerClassResult("mouse", 47.0 / 75))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
confusionMatrixResult.getConfusionMatrix(),
equalTo(Arrays.asList(
new MulticlassConfusionMatrix.ActualClass(
"crocodile", 1, Arrays.asList(new MulticlassConfusionMatrix.PredictedClass("crocodile", 0L)), 1))));
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2);
assertThat(precisionResult.getClasses(), empty());
assertThat(precisionResult.getAvgPrecision(), is(notANumber()));
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3);
assertThat(recallResult.getClasses(), contains(new Recall.PerClassResult("crocodile", 0.0)));
assertThat(recallResult.getAvgRecall(), equalTo(0.0));
}
private void evaluateAccuracy_IntegerField(String actualField) {
private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
new Accuracy.PerClassResult("1", 57.0 / 75),
new Accuracy.PerClassResult("2", 54.0 / 75),
new Accuracy.PerClassResult("3", 51.0 / 75),
new Accuracy.PerClassResult("4", 48.0 / 75),
new Accuracy.PerClassResult("5", 45.0 / 75))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
return accuracyResult;
}
public void testEvaluate_Accuracy_KeywordField() {
List<Accuracy.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Accuracy.PerClassResult("ant", 47.0 / 75),
new Accuracy.PerClassResult("cat", 47.0 / 75),
new Accuracy.PerClassResult("dog", 47.0 / 75),
new Accuracy.PerClassResult("fox", 47.0 / 75),
new Accuracy.PerClassResult("mouse", 47.0 / 75));
double expectedOverallAccuracy = 5.0 / 75;
Accuracy.Result accuracyResult = evaluateAccuracy(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
// Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric)
accuracyResult = evaluateAccuracy(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
}
public void testEvaluate_Accuracy_IntegerField() {
evaluateAccuracy_IntegerField(NO_LEGS_INTEGER_FIELD);
}
List<Accuracy.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Accuracy.PerClassResult("1", 57.0 / 75),
new Accuracy.PerClassResult("2", 54.0 / 75),
new Accuracy.PerClassResult("3", 51.0 / 75),
new Accuracy.PerClassResult("4", 48.0 / 75),
new Accuracy.PerClassResult("5", 45.0 / 75));
double expectedOverallAccuracy = 15.0 / 75;
public void testEvaluate_Accuracy_IntegerField_MappingTypeMismatch() {
evaluateAccuracy_IntegerField(NO_LEGS_KEYWORD_FIELD);
}
Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
private void evaluateAccuracy_BooleanField(String actualField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Accuracy())));
// Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric)
accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
// Actual and predicted fields are of different types but the values are matched correctly
accuracyResult = evaluateAccuracy(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
new Accuracy.PerClassResult("false", 18.0 / 30),
new Accuracy.PerClassResult("true", 27.0 / 45))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
// Actual and predicted fields are of different types but the values are matched correctly
accuracyResult = evaluateAccuracy(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
}
public void testEvaluate_Accuracy_BooleanField() {
evaluateAccuracy_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
List<Accuracy.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Accuracy.PerClassResult("false", 18.0 / 30),
new Accuracy.PerClassResult("true", 27.0 / 45));
double expectedOverallAccuracy = 45.0 / 75;
Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
// Actual and predicted fields are swapped but this does not alter the result (accuracy is symmetric)
accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
// Actual and predicted fields are of different types but the values are matched correctly
accuracyResult = evaluateAccuracy(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
// Actual and predicted fields are of different types but the values are matched correctly
accuracyResult = evaluateAccuracy(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
}
public void testEvaluate_Accuracy_BooleanField_MappingTypeMismatch() {
evaluateAccuracy_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
public void testEvaluate_Accuracy_FieldTypeMismatch() {
{
// When actual and predicted fields have different types, the sets of classes are disjoint
List<Accuracy.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Accuracy.PerClassResult("1", 0.8),
new Accuracy.PerClassResult("2", 0.8),
new Accuracy.PerClassResult("3", 0.8),
new Accuracy.PerClassResult("4", 0.8),
new Accuracy.PerClassResult("5", 0.8));
double expectedOverallAccuracy = 0.0;
Accuracy.Result accuracyResult = evaluateAccuracy(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
}
{
// When actual and predicted fields have different types, the sets of classes are disjoint
List<Accuracy.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Accuracy.PerClassResult("false", 0.6),
new Accuracy.PerClassResult("true", 0.4));
double expectedOverallAccuracy = 0.0;
Accuracy.Result accuracyResult = evaluateAccuracy(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD);
assertThat(accuracyResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(expectedOverallAccuracy));
}
}
private Precision.Result evaluatePrecision(String actualField, String predictedField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Precision())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
return precisionResult;
}
public void testEvaluate_Precision_KeywordField() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
List<Precision.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Precision.PerClassResult("ant", 1.0 / 15),
new Precision.PerClassResult("cat", 1.0 / 15),
new Precision.PerClassResult("dog", 1.0 / 15),
new Precision.PerClassResult("fox", 1.0 / 15),
new Precision.PerClassResult("mouse", 1.0 / 15));
double expectedAvgPrecision = 5.0 / 75;
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Precision.Result precisionResult = evaluatePrecision(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD);
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
assertThat(
precisionResult.getClasses(),
equalTo(
Arrays.asList(
new Precision.PerClassResult("ant", 1.0 / 15),
new Precision.PerClassResult("cat", 1.0 / 15),
new Precision.PerClassResult("dog", 1.0 / 15),
new Precision.PerClassResult("fox", 1.0 / 15),
new Precision.PerClassResult("mouse", 1.0 / 15))));
assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75));
}
private void evaluatePrecision_IntegerField(String actualField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_PREDICTION_FIELD, Arrays.asList(new Precision())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
assertThat(
precisionResult.getClasses(),
equalTo(
Arrays.asList(
new Precision.PerClassResult("1", 0.2),
new Precision.PerClassResult("2", 0.2),
new Precision.PerClassResult("3", 0.2),
new Precision.PerClassResult("4", 0.2),
new Precision.PerClassResult("5", 0.2))));
assertThat(precisionResult.getAvgPrecision(), equalTo(0.2));
evaluatePrecision(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD);
}
public void testEvaluate_Precision_IntegerField() {
evaluatePrecision_IntegerField(NO_LEGS_INTEGER_FIELD);
}
List<Precision.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Precision.PerClassResult("1", 0.2),
new Precision.PerClassResult("2", 0.2),
new Precision.PerClassResult("3", 0.2),
new Precision.PerClassResult("4", 0.2),
new Precision.PerClassResult("5", 0.2));
double expectedAvgPrecision = 0.2;
public void testEvaluate_Precision_IntegerField_MappingTypeMismatch() {
evaluatePrecision_IntegerField(NO_LEGS_KEYWORD_FIELD);
}
Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
private void evaluatePrecision_BooleanField(String actualField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Precision())));
// Actual and predicted fields are of different types but the values are matched correctly
precisionResult = evaluatePrecision(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD);
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
assertThat(
precisionResult.getClasses(),
equalTo(
Arrays.asList(
new Precision.PerClassResult("false", 0.5),
new Precision.PerClassResult("true", 9.0 / 13))));
assertThat(precisionResult.getAvgPrecision(), equalTo(31.0 / 52));
evaluatePrecision(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD);
}
public void testEvaluate_Precision_BooleanField() {
evaluatePrecision_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
List<Precision.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Precision.PerClassResult("false", 0.5),
new Precision.PerClassResult("true", 9.0 / 13));
double expectedAvgPrecision = 31.0 / 52;
Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
// Actual and predicted fields are of different types but the values are matched correctly
precisionResult = evaluatePrecision(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
assertThat(precisionResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(precisionResult.getAvgPrecision(), equalTo(expectedAvgPrecision));
evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
evaluatePrecision(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD);
}
public void testEvaluate_Precision_BooleanField_MappingTypeMismatch() {
evaluatePrecision_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
public void testEvaluate_Precision_FieldTypeMismatch() {
{
Precision.Result precisionResult = evaluatePrecision(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
// When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here
assertThat(precisionResult.getClasses(), empty());
assertThat(precisionResult.getAvgPrecision(), is(notANumber()));
}
{
Precision.Result precisionResult = evaluatePrecision(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD);
// When actual and predicted fields have different types, the sets of classes are disjoint, hence empty results here
assertThat(precisionResult.getClasses(), empty());
assertThat(precisionResult.getAvgPrecision(), is(notANumber()));
}
}
public void testEvaluate_Precision_CardinalityTooHigh() {
@ -257,88 +354,112 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision()))));
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Precision()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
}
public void testEvaluate_Recall_KeywordField() {
private Recall.Result evaluateRecall(String actualField, String predictedField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
assertThat(
recallResult.getClasses(),
equalTo(
Arrays.asList(
new Recall.PerClassResult("ant", 1.0 / 15),
new Recall.PerClassResult("cat", 1.0 / 15),
new Recall.PerClassResult("dog", 1.0 / 15),
new Recall.PerClassResult("fox", 1.0 / 15),
new Recall.PerClassResult("mouse", 1.0 / 15))));
assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75));
return recallResult;
}
private void evaluateRecall_IntegerField(String actualField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(actualField, NO_LEGS_INTEGER_FIELD, Arrays.asList(new Recall())));
public void testEvaluate_Recall_KeywordField() {
List<Recall.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Recall.PerClassResult("ant", 1.0 / 15),
new Recall.PerClassResult("cat", 1.0 / 15),
new Recall.PerClassResult("dog", 1.0 / 15),
new Recall.PerClassResult("fox", 1.0 / 15),
new Recall.PerClassResult("mouse", 1.0 / 15));
double expectedAvgRecall = 5.0 / 75;
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Recall.Result recallResult = evaluateRecall(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD);
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
assertThat(
recallResult.getClasses(),
equalTo(
Arrays.asList(
new Recall.PerClassResult("1", 1.0),
new Recall.PerClassResult("2", 1.0),
new Recall.PerClassResult("3", 1.0),
new Recall.PerClassResult("4", 1.0),
new Recall.PerClassResult("5", 1.0))));
assertThat(recallResult.getAvgRecall(), equalTo(1.0));
evaluateRecall(ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, ANIMAL_NAME_KEYWORD_FIELD);
}
public void testEvaluate_Recall_IntegerField() {
evaluateRecall_IntegerField(NO_LEGS_INTEGER_FIELD);
}
List<Recall.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Recall.PerClassResult("1", 1.0 / 15),
new Recall.PerClassResult("2", 2.0 / 15),
new Recall.PerClassResult("3", 3.0 / 15),
new Recall.PerClassResult("4", 4.0 / 15),
new Recall.PerClassResult("5", 5.0 / 15));
double expectedAvgRecall = 3.0 / 15;
public void testEvaluate_Recall_IntegerField_MappingTypeMismatch() {
evaluateRecall_IntegerField(NO_LEGS_KEYWORD_FIELD);
}
Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
private void evaluateRecall_BooleanField(String actualField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(actualField, IS_PREDATOR_PREDICTION_FIELD, Arrays.asList(new Recall())));
// Actual and predicted fields are of different types but the values are matched correctly
recallResult = evaluateRecall(NO_LEGS_KEYWORD_FIELD, NO_LEGS_PREDICTION_INTEGER_FIELD);
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_INTEGER_FIELD);
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
assertThat(
recallResult.getClasses(),
equalTo(
Arrays.asList(
new Recall.PerClassResult("true", 0.6),
new Recall.PerClassResult("false", 0.6))));
assertThat(recallResult.getAvgRecall(), equalTo(0.6));
evaluateRecall(NO_LEGS_PREDICTION_INTEGER_FIELD, NO_LEGS_KEYWORD_FIELD);
}
public void testEvaluate_Recall_BooleanField() {
evaluateRecall_BooleanField(IS_PREDATOR_BOOLEAN_FIELD);
List<Recall.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Recall.PerClassResult("true", 0.6),
new Recall.PerClassResult("false", 0.6));
double expectedAvgRecall = 0.6;
Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
// Actual and predicted fields are of different types but the values are matched correctly
recallResult = evaluateRecall(IS_PREDATOR_KEYWORD_FIELD, IS_PREDATOR_PREDICTION_BOOLEAN_FIELD);
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
evaluateRecall(IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, IS_PREDATOR_KEYWORD_FIELD);
}
public void testEvaluate_Recall_BooleanField_MappingTypeMismatch() {
evaluateRecall_BooleanField(IS_PREDATOR_KEYWORD_FIELD);
public void testEvaluate_Recall_FieldTypeMismatch() {
{
// When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here
List<Recall.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Recall.PerClassResult("1", 0.0),
new Recall.PerClassResult("2", 0.0),
new Recall.PerClassResult("3", 0.0),
new Recall.PerClassResult("4", 0.0),
new Recall.PerClassResult("5", 0.0));
double expectedAvgRecall = 0.0;
Recall.Result recallResult = evaluateRecall(NO_LEGS_INTEGER_FIELD, IS_PREDATOR_BOOLEAN_FIELD);
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
}
{
// When actual and predicted fields have different types, the sets of classes are disjoint, hence 0.0 results here
List<Recall.PerClassResult> expectedPerClassResults =
Arrays.asList(
new Recall.PerClassResult("true", 0.0),
new Recall.PerClassResult("false", 0.0));
double expectedAvgRecall = 0.0;
Recall.Result recallResult = evaluateRecall(IS_PREDATOR_BOOLEAN_FIELD, NO_LEGS_INTEGER_FIELD);
assertThat(recallResult.getClasses(), equalTo(expectedPerClassResults));
assertThat(recallResult.getAvgRecall(), equalTo(expectedAvgRecall));
}
}
public void testEvaluate_Recall_CardinalityTooHigh() {
@ -348,16 +469,16 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall()))));
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Recall()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
}
private void evaluateWithMulticlassConfusionMatrix() {
private void evaluateMulticlassConfusionMatrix() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -417,16 +538,16 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
}
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
evaluateWithMulticlassConfusionMatrix();
evaluateMulticlassConfusionMatrix();
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get();
evaluateWithMulticlassConfusionMatrix();
evaluateMulticlassConfusionMatrix();
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get();
evaluateWithMulticlassConfusionMatrix();
evaluateMulticlassConfusionMatrix();
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get();
ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix);
ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateMulticlassConfusionMatrix);
assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class)));
TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause();
@ -438,7 +559,9 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
ANIMAL_NAME_KEYWORD_FIELD,
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
Arrays.asList(new MulticlassConfusionMatrix(3, null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -476,13 +599,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc",
ANIMAL_NAME_KEYWORD_FIELD, "type=keyword",
ANIMAL_NAME_PREDICTION_FIELD, "type=keyword",
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, "type=keyword",
NO_LEGS_KEYWORD_FIELD, "type=keyword",
NO_LEGS_INTEGER_FIELD, "type=integer",
NO_LEGS_PREDICTION_FIELD, "type=integer",
NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer",
IS_PREDATOR_KEYWORD_FIELD, "type=keyword",
IS_PREDATOR_BOOLEAN_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_FIELD, "type=boolean")
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean")
.get();
}
@ -497,13 +620,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
new IndexRequest(indexName)
.source(
ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i),
ANIMAL_NAME_PREDICTION_FIELD, animalNames.get((i + j) % animalNames.size()),
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()),
NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1),
NO_LEGS_INTEGER_FIELD, i + 1,
NO_LEGS_PREDICTION_FIELD, j + 1,
NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1,
IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0),
IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0,
IS_PREDATOR_PREDICTION_FIELD, (i + j) % 2 == 0));
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0));
}
}
}
@ -519,7 +642,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
for (int i = 0; i < distinctAnimalCount; i++) {
bulkRequestBuilder.add(
new IndexRequest(indexName)
.source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_FIELD, randomAlphaOfLength(5)));
.source(ANIMAL_NAME_KEYWORD_FIELD, "animal_" + i, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, randomAlphaOfLength(5)));
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {