[ML] Validate that AucRoc has the data necessary to be calculated (#63302) (#63454)

Przemysław Witek 2020-10-08 09:52:15 +02:00 committed by GitHub
parent f4530580e7
commit bd761cce1d
13 changed files with 74 additions and 126 deletions

View File

@ -114,27 +114,23 @@ public class AucRocMetric implements EvaluationMetric {
}
private static final ParseField SCORE = new ParseField("score");
private static final ParseField DOC_COUNT = new ParseField("doc_count");
private static final ParseField CURVE = new ParseField("curve");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
"auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
static {
PARSER.declareDouble(constructorArg(), SCORE);
PARSER.declareLong(constructorArg(), DOC_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
}
private final double score;
private final long docCount;
private final List<AucRocPoint> curve;
public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
public Result(double score, @Nullable List<AucRocPoint> curve) {
this.score = score;
this.docCount = docCount;
this.curve = curve;
}
@ -147,10 +143,6 @@ public class AucRocMetric implements EvaluationMetric {
return score;
}
public long getDocCount() {
return docCount;
}
public List<AucRocPoint> getCurve() {
return curve == null ? null : Collections.unmodifiableList(curve);
}
@ -159,7 +151,6 @@ public class AucRocMetric implements EvaluationMetric {
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(SCORE.getPreferredName(), score);
builder.field(DOC_COUNT.getPreferredName(), docCount);
if (curve != null && curve.isEmpty() == false) {
builder.field(CURVE.getPreferredName(), curve);
}
@ -173,13 +164,12 @@ public class AucRocMetric implements EvaluationMetric {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& docCount == that.docCount
&& Objects.equals(curve, that.curve);
}
@Override
public int hashCode() {
return Objects.hash(score, docCount, curve);
return Objects.hash(score, curve);
}
@Override

View File

@ -200,6 +200,7 @@ import java.util.Locale;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
@ -1931,18 +1932,17 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
createIndex(indexName, mappingForClassification());
BulkRequest regressionBulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "cat", "cat", 0.9))
.add(docForClassification(indexName, "cat", "cat", 0.85))
.add(docForClassification(indexName, "cat", "cat", 0.95))
.add(docForClassification(indexName, "cat", "dog", 0.4))
.add(docForClassification(indexName, "cat", "fish", 0.35))
.add(docForClassification(indexName, "dog", "cat", 0.5))
.add(docForClassification(indexName, "dog", "dog", 0.4))
.add(docForClassification(indexName, "dog", "dog", 0.35))
.add(docForClassification(indexName, "dog", "dog", 0.6))
.add(docForClassification(indexName, "ant", "cat", 0.1));
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", "horse", "dog"))
.add(docForClassification(indexName, "cat", "dog", "cat", "mule"))
.add(docForClassification(indexName, "cat", "fish", "cat", "dog"))
.add(docForClassification(indexName, "dog", "cat", "dog", "mule"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "ant", "cat", "ant", "wasp"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
{ // AucRoc
@ -1957,8 +1957,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
assertThat(aucRocResult.getDocCount(), equalTo(5L));
assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9));
assertNotNull(aucRocResult.getCurve());
}
{ // Accuracy
@ -2173,21 +2172,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.endObject();
}
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
private static IndexRequest docForClassification(String indexName,
String actualClass,
String... topPredictedClasses) {
assert topPredictedClasses.length > 0;
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON,
actualClassField, actualClass,
predictedClassField, predictedClass,
topClassesField, Arrays.asList(
new HashMap<String, Object>() {{
put("class_name", predictedClass);
put("class_probability", p);
}},
new HashMap<String, Object>() {{
put("class_name", "other");
put("class_probability", 1 - p);
}}));
predictedClassField, topPredictedClasses[0],
topClassesField, IntStream.range(0, topPredictedClasses.length)
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
.mapToObj(i -> new HashMap<String, Object>() {{
put("class_name", topPredictedClasses[i]);
put("class_probability", 1.0 / (2 << i));
}})
.collect(Collectors.toList()));
}
private static final String actualRegression = "regression_actual";
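
The halving-probability scheme used by the rewritten docForClassification helper above can be shown outside of the test. The following is a minimal, self-contained sketch (class and method names are illustrative, not part of this commit); it reproduces the 1.0 / (2 << i) assignment so the resulting class_probability values are 0.5, 0.25, 0.125, and so on, in the order the predicted classes are passed in.

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch of the probability scheme used by docForClassification above.
public class TopClassesProbabilitySketch {

    // Builds the ml.top_classes payload: each class gets half the probability of the one before it.
    static List<Map<String, Object>> topClasses(String... topPredictedClasses) {
        List<Map<String, Object>> topClasses = new ArrayList<>();
        for (int i = 0; i < topPredictedClasses.length; i++) {
            Map<String, Object> entry = new LinkedHashMap<>();
            entry.put("class_name", topPredictedClasses[i]);
            entry.put("class_probability", 1.0 / (2 << i)); // i = 0 -> 0.5, i = 1 -> 0.25, i = 2 -> 0.125
            topClasses.add(entry);
        }
        return topClasses;
    }

    public static void main(String[] args) {
        // Mirrors docForClassification(indexName, "cat", "cat", "dog", "ant") from the test above.
        System.out.println(topClasses("cat", "dog", "ant"));
        // [{class_name=cat, class_probability=0.5}, {class_name=dog, class_probability=0.25}, {class_name=ant, class_probability=0.125}]
    }
}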

View File

@ -201,7 +201,6 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition;
import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
@ -229,8 +228,11 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
@ -3463,34 +3465,33 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.endObject()
.endObject()
.endObject());
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
BiFunction<String, String[], IndexRequest> indexRequest = (actualClass, topPredictedClasses) -> {
assert topPredictedClasses.length > 0;
return new IndexRequest()
.source(XContentType.JSON,
"actual_class", actualClass,
"predicted_class", predictedClass,
"ml.top_classes", Arrays.asList(
new HashMap<String, Object>() {{
put("class_name", predictedClass);
put("class_probability", p);
}},
new HashMap<String, Object>() {{
put("class_name", "other");
put("class_probability", 1 - p);
}}));
"predicted_class", topPredictedClasses[0],
"ml.top_classes", IntStream.range(0, topPredictedClasses.length)
// Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
.mapToObj(i -> new HashMap<String, Object>() {{
put("class_name", topPredictedClasses[i]);
put("class_probability", 1.0 / (2 << i));
}})
.collect(toList()));
};
BulkRequest bulkRequest =
new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(indexRequest.apply("cat", "cat", 0.9)) // #0
.add(indexRequest.apply("cat", "cat", 0.9)) // #1
.add(indexRequest.apply("cat", "cat", 0.9)) // #2
.add(indexRequest.apply("cat", "dog", 0.9)) // #3
.add(indexRequest.apply("cat", "fox", 0.9)) // #4
.add(indexRequest.apply("dog", "cat", 0.9)) // #5
.add(indexRequest.apply("dog", "dog", 0.9)) // #6
.add(indexRequest.apply("dog", "dog", 0.9)) // #7
.add(indexRequest.apply("dog", "dog", 0.9)) // #8
.add(indexRequest.apply("ant", "cat", 0.9)); // #9
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0
.add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1
.add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2
.add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3
.add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4
.add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7
.add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8
.add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9
RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT);
@ -3530,7 +3531,6 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
double aucRocScore = aucRocResult.getScore(); // <11>
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
// end::evaluate-data-frame-results-classification
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
@ -3565,8 +3565,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
assertThat(otherClassesCount, equalTo(0L));
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocScore, equalTo(0.2625));
assertThat(aucRocDocCount, equalTo(5L));
assertThat(aucRocScore, closeTo(0.6425, 1e-9));
}
}

View File

@ -31,7 +31,6 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
public static AucRocMetric.Result randomResult() {
return new AucRocMetric.Result(
randomDouble(),
randomLong(),
Stream
.generate(AucRocMetricAucRocPointTests::randomPoint)
.limit(randomIntBetween(1, 10))

View File

@ -121,7 +121,6 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
<9> Fetching the number of classes that were not included in the matrix
<10> Fetching AucRoc metric by name
<11> Fetching the actual AucRoc score
<12> Fetching the number of documents that were used in order to calculate AucRoc score
===== Regression

View File

@ -193,10 +193,8 @@ belongs.
`class_name`::::
(Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name`
in the list of their top classes will not be taken into account for
evaluation. The number of documents taken into account is returned in the
evaluation result (`auc_roc.doc_count` field).
negative ("one-vs-all" strategy). All the evaluated documents must have `class_name`
in the list of their top classes.
`include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in
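
To make the one-vs-all requirement described for `class_name` above concrete, here is a hedged, standalone sketch (not the Elasticsearch implementation; the document and field names are illustrative): each document contributes a binary label, 1 when its actual class equals `class_name` and 0 otherwise, together with the probability the model assigned to `class_name` in its top classes. A document whose top classes do not contain `class_name` has no usable probability, which is exactly the case the new validation rejects.

import java.util.List;
import java.util.Map;

// Hedged sketch of one-vs-all label/score extraction for AUC ROC.
public class OneVsAllSketch {

    // A hypothetical evaluated document: its actual class and the model's top classes with probabilities.
    record Doc(String actualClass, Map<String, Double> topClassProbabilities) {}

    public static void main(String[] args) {
        String className = "cat"; // the class treated as "positive"
        List<Doc> docs = List.of(
            new Doc("cat", Map.of("cat", 0.5, "dog", 0.25, "ant", 0.125)),
            new Doc("dog", Map.of("dog", 0.5, "cat", 0.25, "ant", 0.125)),
            new Doc("ant", Map.of("cat", 0.5, "ant", 0.25, "wasp", 0.125))
        );
        for (Doc doc : docs) {
            int label = doc.actualClass().equals(className) ? 1 : 0;   // one-vs-all: "cat" vs everything else
            Double score = doc.topClassProbabilities().get(className); // model probability for "cat"
            if (score == null) {
                // The situation the commit validates against: class_name is missing from the top classes,
                // so the document cannot be scored for AUC ROC.
                throw new IllegalArgumentException("top classes do not contain [" + className + "]");
            }
            System.out.println("label=" + label + " score=" + score);
        }
    }
}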

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@ -231,26 +230,18 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
public static class Result implements EvaluationMetricResult {
private static final String SCORE = "score";
private static final String DOC_COUNT = "doc_count";
private static final String CURVE = "curve";
private final double score;
private final Long docCount;
private final List<AucRocPoint> curve;
public Result(double score, Long docCount, List<AucRocPoint> curve) {
public Result(double score, List<AucRocPoint> curve) {
this.score = score;
this.docCount = docCount;
this.curve = Objects.requireNonNull(curve);
}
public Result(StreamInput in) throws IOException {
this.score = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.docCount = in.readOptionalLong();
} else {
this.docCount = null;
}
this.curve = in.readList(AucRocPoint::new);
}
@ -258,10 +249,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
return score;
}
public Long getDocCount() {
return docCount;
}
public List<AucRocPoint> getCurve() {
return Collections.unmodifiableList(curve);
}
@ -279,9 +266,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(score);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeOptionalLong(docCount);
}
out.writeList(curve);
}
@ -289,9 +273,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SCORE, score);
if (docCount != null) {
builder.field(DOC_COUNT, docCount);
}
if (curve.isEmpty() == false) {
builder.field(CURVE, curve);
}
@ -305,13 +286,12 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& Objects.equals(docCount, that.docCount)
&& Objects.equals(curve, that.curve);
}
@Override
public int hashCode() {
return Objects.hash(score, docCount, curve);
return Objects.hash(score, curve);
}
}
}

View File

@ -183,42 +183,39 @@ public class AucRoc extends AbstractAucRoc {
Filter classAgg = aggs.get(TRUE_AGG_NAME);
Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME);
Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
if (classAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getActualField(), className);
}
if (classNestedFilter.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getPredictedClassField(), className);
}
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] tpPercentiles = percentilesArray(classPercentiles);
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
if (restAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have a different value than [{}]",
getName(), fields.get().getActualField(), className);
}
if (restNestedFilter.getDocCount() == 0) {
long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount();
long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount();
if (filteredDocCount < totalDocCount) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getPredictedClassField(), className);
"[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). "
+ "This is probably caused by the {} value being less than the total number of actual classes in the dataset.",
getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount,
org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName());
}
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] tpPercentiles = percentilesArray(classPercentiles);
Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] fpPercentiles = percentilesArray(restPercentiles);
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
result.set(
new Result(
aucRocScore,
classNestedFilter.getDocCount() + restNestedFilter.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
}
@Override
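
Stripped of the aggregation plumbing, the check this hunk introduces can be sketched as plain Java (names here are illustrative; the real code reads the counts from the nested filter aggregations shown above): the number of documents whose top classes contain `class_name` must equal the total number of evaluated documents, otherwise the request is rejected with the message asserted in the YAML test further down.

// Minimal sketch of the new validation, decoupled from Elasticsearch aggregations.
public class AucRocValidationSketch {

    static void validateEveryDocHasClassName(String metricName,
                                             String className,
                                             String predictedClassField,
                                             long docsWithClassNameInTopClasses,
                                             long totalDocs) {
        if (docsWithClassNameInTopClasses < totalDocs) {
            // Mirrors the intent of ExceptionsHelper.badRequestException in the hunk above.
            throw new IllegalArgumentException(String.format(
                "[%s] requires that [%s] appears as one of the [%s] for every document (appeared in %d out of %d)",
                metricName, className, predictedClassField, docsWithClassNameInTopClasses, totalDocs));
        }
    }

    public static void main(String[] args) {
        try {
            // Matches the YAML test expectation: "mouse" never appears in the top classes of the 8 documents.
            validateEveryDocHasClassName("auc_roc", "mouse", "ml.top_classes.class_name", 0, 8);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}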

View File

@ -174,11 +174,7 @@ public class AucRoc extends AbstractAucRoc {
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
result.set(
new Result(
aucRocScore,
classAgg.getDocCount() + restAgg.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
}
@Override
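
The diff only shows buildAucRocCurve and calculateAucScore being called, not their bodies. As a purely illustrative aside (an assumption, not the actual implementation), a common way to turn ROC points into a single score is the trapezoid rule over (false positive rate, true positive rate) points ordered by false positive rate:

// Assumed illustration only: trapezoid-rule AUC over ROC points sorted by false positive rate.
public class TrapezoidAucSketch {

    static double trapezoidAuc(double[] fpr, double[] tpr) {
        double auc = 0.0;
        for (int i = 1; i < fpr.length; i++) {
            auc += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0; // area of one trapezoid slice
        }
        return auc;
    }

    public static void main(String[] args) {
        // A perfect classifier's ROC curve hugs the top-left corner: AUC = 1.0.
        System.out.println(trapezoidAuc(new double[]{0.0, 0.0, 1.0}, new double[]{0.0, 1.0, 1.0}));
        // A random classifier's ROC curve is the diagonal: AUC = 0.5.
        System.out.println(trapezoidAuc(new double[]{0.0, 1.0}, new double[]{0.0, 1.0}));
    }
}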

View File

@ -20,13 +20,12 @@ public class AucRocResultTests extends AbstractWireSerializingTestCase<Result> {
public static Result createRandom() {
double score = randomDoubleBetween(0.0, 1.0, true);
Long docCount = randomBoolean() ? randomLong() : null;
List<AucRocPoint> curve =
Stream
.generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble()))
.limit(randomIntBetween(0, 20))
.collect(Collectors.toList());
return new Result(score, docCount, curve);
return new Result(score, curve);
}
@Override

View File

@ -167,14 +167,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
public void testEvaluate_AucRoc_DoNotIncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(false);
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
assertThat(aucrocResult.getCurve(), hasSize(0));
}
public void testEvaluate_AucRoc_IncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(true);
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
}

View File

@ -981,7 +981,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1);
assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
assertThat(aucRocResult.getDocCount(), allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo(350L)));
assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0)));
}

View File

@ -207,7 +207,6 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve
---
@ -228,7 +227,6 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve
---
@ -249,7 +247,6 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_true: outlier_detection.auc_roc.curve
---
@ -415,7 +412,6 @@ setup:
}
}
- is_true: outlier_detection.auc_roc.score
- is_true: outlier_detection.auc_roc.doc_count
- is_true: outlier_detection.precision.0\.25
- is_true: outlier_detection.precision.0\.5
- is_true: outlier_detection.precision.0\.75
@ -689,7 +685,7 @@ setup:
---
"Test classification auc_roc given predicted_class_field is never equal to mouse":
- do:
catch: /\[auc_roc\] requires at least one \[ml.top_classes.class_name\] to have the value \[mouse\]/
catch: /\[auc_roc\] requires that \[mouse\] appears as one of the \[ml.top_classes.class_name\] for every document \(appeared in 0 out of 8\)./
ml.evaluate_data_frame:
body: >
{
@ -726,7 +722,6 @@ setup:
}
}
- match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve
---
"Test classification auc_roc with default top_classes_field":
@ -747,7 +742,6 @@ setup:
}
}
- match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve
---
"Test classification accuracy with missing predicted_field":