[ML] Validate that AucRoc has the data necessary to be calculated (#63302) (#63454)

This commit is contained in:
Przemysław Witek 2020-10-08 09:52:15 +02:00 committed by GitHub
parent f4530580e7
commit bd761cce1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 74 additions and 126 deletions

View File

@ -114,27 +114,23 @@ public class AucRocMetric implements EvaluationMetric {
} }
private static final ParseField SCORE = new ParseField("score"); private static final ParseField SCORE = new ParseField("score");
private static final ParseField DOC_COUNT = new ParseField("doc_count");
private static final ParseField CURVE = new ParseField("curve"); private static final ParseField CURVE = new ParseField("curve");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>( new ConstructingObjectParser<>(
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2])); "auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
static { static {
PARSER.declareDouble(constructorArg(), SCORE); PARSER.declareDouble(constructorArg(), SCORE);
PARSER.declareLong(constructorArg(), DOC_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE); PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
} }
private final double score; private final double score;
private final long docCount;
private final List<AucRocPoint> curve; private final List<AucRocPoint> curve;
public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) { public Result(double score, @Nullable List<AucRocPoint> curve) {
this.score = score; this.score = score;
this.docCount = docCount;
this.curve = curve; this.curve = curve;
} }
@ -147,10 +143,6 @@ public class AucRocMetric implements EvaluationMetric {
return score; return score;
} }
public long getDocCount() {
return docCount;
}
public List<AucRocPoint> getCurve() { public List<AucRocPoint> getCurve() {
return curve == null ? null : Collections.unmodifiableList(curve); return curve == null ? null : Collections.unmodifiableList(curve);
} }
@ -159,7 +151,6 @@ public class AucRocMetric implements EvaluationMetric {
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(SCORE.getPreferredName(), score); builder.field(SCORE.getPreferredName(), score);
builder.field(DOC_COUNT.getPreferredName(), docCount);
if (curve != null && curve.isEmpty() == false) { if (curve != null && curve.isEmpty() == false) {
builder.field(CURVE.getPreferredName(), curve); builder.field(CURVE.getPreferredName(), curve);
} }
@ -173,13 +164,12 @@ public class AucRocMetric implements EvaluationMetric {
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return score == that.score return score == that.score
&& docCount == that.docCount
&& Objects.equals(curve, that.curve); && Objects.equals(curve, that.curve);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(score, docCount, curve); return Objects.hash(score, curve);
} }
@Override @Override

View File

@ -200,6 +200,7 @@ import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.anyOf;
@ -1931,18 +1932,17 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
createIndex(indexName, mappingForClassification()); createIndex(indexName, mappingForClassification());
BulkRequest regressionBulk = new BulkRequest() BulkRequest regressionBulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "cat", "cat", 0.9)) .add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", 0.85)) .add(docForClassification(indexName, "cat", "cat", "dog", "ant"))
.add(docForClassification(indexName, "cat", "cat", 0.95)) .add(docForClassification(indexName, "cat", "cat", "horse", "dog"))
.add(docForClassification(indexName, "cat", "dog", 0.4)) .add(docForClassification(indexName, "cat", "dog", "cat", "mule"))
.add(docForClassification(indexName, "cat", "fish", 0.35)) .add(docForClassification(indexName, "cat", "fish", "cat", "dog"))
.add(docForClassification(indexName, "dog", "cat", 0.5)) .add(docForClassification(indexName, "dog", "cat", "dog", "mule"))
.add(docForClassification(indexName, "dog", "dog", 0.4)) .add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", 0.35)) .add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "dog", "dog", 0.6)) .add(docForClassification(indexName, "dog", "dog", "cat", "ant"))
.add(docForClassification(indexName, "ant", "cat", 0.1)); .add(docForClassification(indexName, "ant", "cat", "ant", "wasp"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
{ // AucRoc { // AucRoc
@ -1957,8 +1957,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9)); assertThat(aucRocResult.getScore(), closeTo(0.6425, 1e-9));
assertThat(aucRocResult.getDocCount(), equalTo(5L));
assertNotNull(aucRocResult.getCurve()); assertNotNull(aucRocResult.getCurve());
} }
{ // Accuracy { // Accuracy
@ -2173,21 +2172,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.endObject(); .endObject();
} }
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) { private static IndexRequest docForClassification(String indexName,
String actualClass,
String... topPredictedClasses) {
assert topPredictedClasses.length > 0;
return new IndexRequest() return new IndexRequest()
.index(indexName) .index(indexName)
.source(XContentType.JSON, .source(XContentType.JSON,
actualClassField, actualClass, actualClassField, actualClass,
predictedClassField, predictedClass, predictedClassField, topPredictedClasses[0],
topClassesField, Arrays.asList( topClassesField, IntStream.range(0, topPredictedClasses.length)
new HashMap<String, Object>() {{ // Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
put("class_name", predictedClass); .mapToObj(i -> new HashMap<String, Object>() {{
put("class_probability", p); put("class_name", topPredictedClasses[i]);
}}, put("class_probability", 1.0 / (2 << i));
new HashMap<String, Object>() {{ }})
put("class_name", "other"); .collect(Collectors.toList()));
put("class_probability", 1 - p);
}}));
} }
private static final String actualRegression = "regression_actual"; private static final String actualRegression = "regression_actual";

View File

@ -201,7 +201,6 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition;
import org.elasticsearch.client.ml.job.results.Influencer; import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
@ -229,8 +228,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.contains;
@ -3463,34 +3465,33 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.endObject() .endObject()
.endObject() .endObject()
.endObject()); .endObject());
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> { BiFunction<String, String[], IndexRequest> indexRequest = (actualClass, topPredictedClasses) -> {
assert topPredictedClasses.length > 0;
return new IndexRequest() return new IndexRequest()
.source(XContentType.JSON, .source(XContentType.JSON,
"actual_class", actualClass, "actual_class", actualClass,
"predicted_class", predictedClass, "predicted_class", topPredictedClasses[0],
"ml.top_classes", Arrays.asList( "ml.top_classes", IntStream.range(0, topPredictedClasses.length)
new HashMap<String, Object>() {{ // Consecutive assigned probabilities are: 0.5, 0.25, 0.125, etc.
put("class_name", predictedClass); .mapToObj(i -> new HashMap<String, Object>() {{
put("class_probability", p); put("class_name", topPredictedClasses[i]);
}}, put("class_probability", 1.0 / (2 << i));
new HashMap<String, Object>() {{ }})
put("class_name", "other"); .collect(toList()));
put("class_probability", 1 - p);
}}));
}; };
BulkRequest bulkRequest = BulkRequest bulkRequest =
new BulkRequest(indexName) new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(indexRequest.apply("cat", "cat", 0.9)) // #0 .add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #0
.add(indexRequest.apply("cat", "cat", 0.9)) // #1 .add(indexRequest.apply("cat", new String[]{"cat", "dog", "ant"})) // #1
.add(indexRequest.apply("cat", "cat", 0.9)) // #2 .add(indexRequest.apply("cat", new String[]{"cat", "horse", "dog"})) // #2
.add(indexRequest.apply("cat", "dog", 0.9)) // #3 .add(indexRequest.apply("cat", new String[]{"dog", "cat", "mule"})) // #3
.add(indexRequest.apply("cat", "fox", 0.9)) // #4 .add(indexRequest.apply("cat", new String[]{"fox", "cat", "dog"})) // #4
.add(indexRequest.apply("dog", "cat", 0.9)) // #5 .add(indexRequest.apply("dog", new String[]{"cat", "dog", "mule"})) // #5
.add(indexRequest.apply("dog", "dog", 0.9)) // #6 .add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #6
.add(indexRequest.apply("dog", "dog", 0.9)) // #7 .add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #7
.add(indexRequest.apply("dog", "dog", 0.9)) // #8 .add(indexRequest.apply("dog", new String[]{"dog", "cat", "ant"})) // #8
.add(indexRequest.apply("ant", "cat", 0.9)); // #9 .add(indexRequest.apply("ant", new String[]{"cat", "ant", "wasp"})); // #9
RestHighLevelClient client = highLevelClient(); RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT); client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT); client.bulk(bulkRequest, RequestOptions.DEFAULT);
@ -3530,7 +3531,6 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10> AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
double aucRocScore = aucRocResult.getScore(); // <11> double aucRocScore = aucRocResult.getScore(); // <11>
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
// end::evaluate-data-frame-results-classification // end::evaluate-data-frame-results-classification
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
@ -3565,8 +3565,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
assertThat(otherClassesCount, equalTo(0L)); assertThat(otherClassesCount, equalTo(0L));
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocScore, equalTo(0.2625)); assertThat(aucRocScore, closeTo(0.6425, 1e-9));
assertThat(aucRocDocCount, equalTo(5L));
} }
} }

View File

@ -31,7 +31,6 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
public static AucRocMetric.Result randomResult() { public static AucRocMetric.Result randomResult() {
return new AucRocMetric.Result( return new AucRocMetric.Result(
randomDouble(), randomDouble(),
randomLong(),
Stream Stream
.generate(AucRocMetricAucRocPointTests::randomPoint) .generate(AucRocMetricAucRocPointTests::randomPoint)
.limit(randomIntBetween(1, 10)) .limit(randomIntBetween(1, 10))

View File

@ -121,7 +121,6 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
<9> Fetching the number of classes that were not included in the matrix <9> Fetching the number of classes that were not included in the matrix
<10> Fetching AucRoc metric by name <10> Fetching AucRoc metric by name
<11> Fetching the actual AucRoc score <11> Fetching the actual AucRoc score
<12> Fetching the number of documents that were used in order to calculate AucRoc score
===== Regression ===== Regression

View File

@ -193,10 +193,8 @@ belongs.
`class_name`:::: `class_name`::::
(Required, string) Name of the only class that will be treated as (Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name` negative ("one-vs-all" strategy). All the evaluated documents must have `class_name`
in the list of their top classes will not be taken into account for in the list of their top classes.
evaluation. The number of documents taken into account is returned in the
evaluation result (`auc_roc.doc_count` field).
`include_curve`:::: `include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in (Optional, boolean) Whether or not the curve should be returned in

View File

@ -5,7 +5,6 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@ -231,26 +230,18 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
public static class Result implements EvaluationMetricResult { public static class Result implements EvaluationMetricResult {
private static final String SCORE = "score"; private static final String SCORE = "score";
private static final String DOC_COUNT = "doc_count";
private static final String CURVE = "curve"; private static final String CURVE = "curve";
private final double score; private final double score;
private final Long docCount;
private final List<AucRocPoint> curve; private final List<AucRocPoint> curve;
public Result(double score, Long docCount, List<AucRocPoint> curve) { public Result(double score, List<AucRocPoint> curve) {
this.score = score; this.score = score;
this.docCount = docCount;
this.curve = Objects.requireNonNull(curve); this.curve = Objects.requireNonNull(curve);
} }
public Result(StreamInput in) throws IOException { public Result(StreamInput in) throws IOException {
this.score = in.readDouble(); this.score = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.docCount = in.readOptionalLong();
} else {
this.docCount = null;
}
this.curve = in.readList(AucRocPoint::new); this.curve = in.readList(AucRocPoint::new);
} }
@ -258,10 +249,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
return score; return score;
} }
public Long getDocCount() {
return docCount;
}
public List<AucRocPoint> getCurve() { public List<AucRocPoint> getCurve() {
return Collections.unmodifiableList(curve); return Collections.unmodifiableList(curve);
} }
@ -279,9 +266,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(score); out.writeDouble(score);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeOptionalLong(docCount);
}
out.writeList(curve); out.writeList(curve);
} }
@ -289,9 +273,6 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(SCORE, score); builder.field(SCORE, score);
if (docCount != null) {
builder.field(DOC_COUNT, docCount);
}
if (curve.isEmpty() == false) { if (curve.isEmpty() == false) {
builder.field(CURVE, curve); builder.field(CURVE, curve);
} }
@ -305,13 +286,12 @@ public abstract class AbstractAucRoc implements EvaluationMetric {
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return score == that.score return score == that.score
&& Objects.equals(docCount, that.docCount)
&& Objects.equals(curve, that.curve); && Objects.equals(curve, that.curve);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(score, docCount, curve); return Objects.hash(score, curve);
} }
} }
} }

View File

@ -183,42 +183,39 @@ public class AucRoc extends AbstractAucRoc {
Filter classAgg = aggs.get(TRUE_AGG_NAME); Filter classAgg = aggs.get(TRUE_AGG_NAME);
Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME); Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME);
Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME); Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
if (classAgg.getDocCount() == 0) { if (classAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]", "[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getActualField(), className); getName(), fields.get().getActualField(), className);
} }
if (classNestedFilter.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getPredictedClassField(), className);
}
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] tpPercentiles = percentilesArray(classPercentiles);
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
if (restAgg.getDocCount() == 0) { if (restAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have a different value than [{}]", "[{}] requires at least one [{}] to have a different value than [{}]",
getName(), fields.get().getActualField(), className); getName(), fields.get().getActualField(), className);
} }
if (restNestedFilter.getDocCount() == 0) { long filteredDocCount = classNestedFilter.getDocCount() + restNestedFilter.getDocCount();
long totalDocCount = classAgg.getDocCount() + restAgg.getDocCount();
if (filteredDocCount < totalDocCount) {
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]", "[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). "
getName(), fields.get().getPredictedClassField(), className); + "This is probably caused by the {} value being less than the total number of actual classes in the dataset.",
getName(), className, fields.get().getPredictedClassField(), filteredDocCount, totalDocCount,
org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName());
} }
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] tpPercentiles = percentilesArray(classPercentiles);
Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME); Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] fpPercentiles = percentilesArray(restPercentiles); double[] fpPercentiles = percentilesArray(restPercentiles);
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve); double aucRocScore = calculateAucScore(aucRocCurve);
result.set( result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
new Result(
aucRocScore,
classNestedFilter.getDocCount() + restNestedFilter.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
} }
@Override @Override

View File

@ -174,11 +174,7 @@ public class AucRoc extends AbstractAucRoc {
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve); double aucRocScore = calculateAucScore(aucRocCurve);
result.set( result.set(new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()));
new Result(
aucRocScore,
classAgg.getDocCount() + restAgg.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
} }
@Override @Override

View File

@ -20,13 +20,12 @@ public class AucRocResultTests extends AbstractWireSerializingTestCase<Result> {
public static Result createRandom() { public static Result createRandom() {
double score = randomDoubleBetween(0.0, 1.0, true); double score = randomDoubleBetween(0.0, 1.0, true);
Long docCount = randomBoolean() ? randomLong() : null;
List<AucRocPoint> curve = List<AucRocPoint> curve =
Stream Stream
.generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble())) .generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble()))
.limit(randomIntBetween(0, 20)) .limit(randomIntBetween(0, 20))
.collect(Collectors.toList()); .collect(Collectors.toList());
return new Result(score, docCount, curve); return new Result(score, curve);
} }
@Override @Override

View File

@ -167,14 +167,12 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
public void testEvaluate_AucRoc_DoNotIncludeCurve() { public void testEvaluate_AucRoc_DoNotIncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(false); AucRoc.Result aucrocResult = evaluateAucRoc(false);
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
assertThat(aucrocResult.getCurve(), hasSize(0)); assertThat(aucrocResult.getCurve(), hasSize(0));
} }
public void testEvaluate_AucRoc_IncludeCurve() { public void testEvaluate_AucRoc_IncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(true); AucRoc.Result aucrocResult = evaluateAucRoc(true);
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001))); assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0))); assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
} }

View File

@ -981,7 +981,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1); AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1);
assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName())); assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
assertThat(aucRocResult.getDocCount(), allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo(350L)));
assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0))); assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0)));
} }

View File

@ -207,7 +207,6 @@ setup:
} }
} }
- match: { outlier_detection.auc_roc.score: 0.9899 } - match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve - is_false: outlier_detection.auc_roc.curve
--- ---
@ -228,7 +227,6 @@ setup:
} }
} }
- match: { outlier_detection.auc_roc.score: 0.9899 } - match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve - is_false: outlier_detection.auc_roc.curve
--- ---
@ -249,7 +247,6 @@ setup:
} }
} }
- match: { outlier_detection.auc_roc.score: 0.9899 } - match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_true: outlier_detection.auc_roc.curve - is_true: outlier_detection.auc_roc.curve
--- ---
@ -415,7 +412,6 @@ setup:
} }
} }
- is_true: outlier_detection.auc_roc.score - is_true: outlier_detection.auc_roc.score
- is_true: outlier_detection.auc_roc.doc_count
- is_true: outlier_detection.precision.0\.25 - is_true: outlier_detection.precision.0\.25
- is_true: outlier_detection.precision.0\.5 - is_true: outlier_detection.precision.0\.5
- is_true: outlier_detection.precision.0\.75 - is_true: outlier_detection.precision.0\.75
@ -689,7 +685,7 @@ setup:
--- ---
"Test classification auc_roc given predicted_class_field is never equal to mouse": "Test classification auc_roc given predicted_class_field is never equal to mouse":
- do: - do:
catch: /\[auc_roc\] requires at least one \[ml.top_classes.class_name\] to have the value \[mouse\]/ catch: /\[auc_roc\] requires that \[mouse\] appears as one of the \[ml.top_classes.class_name\] for every document \(appeared in 0 out of 8\)./
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
{ {
@ -726,7 +722,6 @@ setup:
} }
} }
- match: { classification.auc_roc.score: 0.8050111095212122 } - match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve - is_false: classification.auc_roc.curve
--- ---
"Test classification auc_roc with default top_classes_field": "Test classification auc_roc with default top_classes_field":
@ -747,7 +742,6 @@ setup:
} }
} }
- match: { classification.auc_roc.score: 0.8050111095212122 } - match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve - is_false: classification.auc_roc.curve
--- ---
"Test classification accuracy with missing predicted_field": "Test classification accuracy with missing predicted_field":