[7.x][ML] Assert top classes are ordered by score (#51028)
Backport #51003.
This commit is contained in:
parent
b345c7ff31
commit
32ec934b15
|
@ -5,8 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
// Pending fix
|
||||
//import com.google.common.collect.Ordering;
|
||||
import com.google.common.collect.Ordering;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
|
||||
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
|
||||
|
@ -582,9 +581,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(topClasses, hasSize(numTopClasses));
|
||||
List<T> classNames = new ArrayList<>(topClasses.size());
|
||||
List<Double> classProbabilities = new ArrayList<>(topClasses.size());
|
||||
List<Double> classScores = new ArrayList<>(topClasses.size());
|
||||
for (Map<String, Object> topClass : topClasses) {
|
||||
classNames.add(getFieldValue(topClass, "class_name"));
|
||||
classProbabilities.add(getFieldValue(topClass, "class_probability"));
|
||||
classScores.add(getFieldValue(topClass, "class_score"));
|
||||
}
|
||||
// Assert that all the predicted class names come from the set of dependent variable values.
|
||||
classNames.forEach(className -> assertThat(className, is(in(dependentVariableValues))));
|
||||
|
@ -592,10 +593,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
|
||||
// Assert that all the class probabilities lie within [0, 1] interval.
|
||||
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
|
||||
// Assert that the top classes are listed in the order of decreasing probabilities.
|
||||
// This is not true after https://github.com/elastic/ml-cpp/pull/926. I'll fix and re-enable
|
||||
// once that change is merged.
|
||||
//assertThat(Ordering.natural().reverse().isOrdered(classProbabilities), is(true));
|
||||
// Assert that the top classes are listed in the order of decreasing scores.
|
||||
assertThat(Ordering.natural().reverse().isOrdered(classScores), is(true));
|
||||
}
|
||||
|
||||
private <T> void assertEvaluation(String dependentVariable, List<T> dependentVariableValues, String predictedClassField) {
|
||||
|
|
Loading…
Reference in New Issue