[7.x][ML] Assert top classes are ordered by score (#51028)

Backport #51003.
This commit is contained in:
Tom Veasey 2020-01-16 12:23:15 +00:00 committed by GitHub
parent b345c7ff31
commit 32ec934b15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 6 deletions

View File

@ -5,8 +5,7 @@
*/
package org.elasticsearch.xpack.ml.integration;
// Pending fix
//import com.google.common.collect.Ordering;
import com.google.common.collect.Ordering;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
@ -582,9 +581,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(topClasses, hasSize(numTopClasses));
List<T> classNames = new ArrayList<>(topClasses.size());
List<Double> classProbabilities = new ArrayList<>(topClasses.size());
List<Double> classScores = new ArrayList<>(topClasses.size());
for (Map<String, Object> topClass : topClasses) {
classNames.add(getFieldValue(topClass, "class_name"));
classProbabilities.add(getFieldValue(topClass, "class_probability"));
classScores.add(getFieldValue(topClass, "class_score"));
}
// Assert that all the predicted class names come from the set of dependent variable values.
classNames.forEach(className -> assertThat(className, is(in(dependentVariableValues))));
@ -592,10 +593,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
// Assert that all the class probabilities lie within [0, 1] interval.
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
// Assert that the top classes are listed in the order of decreasing probabilities.
// This is not true after https://github.com/elastic/ml-cpp/pull/926. I'll fix and re-enable
// once that change is merged.
//assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true));
// Assert that the top classes are listed in the order of decreasing scores.
assertThat(Ordering.natural().reverse().isOrdered(classScores), is(true));
}
private <T> void assertEvaluation(String dependentVariable, List<T> dependentVariableValues, String predictedClassField) {