A few cleanups in evaluation tests (#49791) (#49794)

This commit is contained in:
Przemysław Witek 2019-12-03 15:48:39 +01:00 committed by GitHub
parent fbb92f527a
commit a3f88595d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 276 additions and 241 deletions

View File

@ -33,9 +33,6 @@ import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.GetIndexRequest;
import org.elasticsearch.client.ml.CloseJobRequest;
import org.elasticsearch.client.ml.CloseJobResponse;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
import org.elasticsearch.client.ml.DeleteCalendarEventRequest;
import org.elasticsearch.client.ml.DeleteCalendarJobRequest;
import org.elasticsearch.client.ml.DeleteCalendarRequest;
@ -48,8 +45,11 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteJobResponse;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
import org.elasticsearch.client.ml.FindFileStructureRequest;
import org.elasticsearch.client.ml.FindFileStructureResponse;
import org.elasticsearch.client.ml.FlushJobRequest;
@ -135,8 +135,6 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
@ -1852,9 +1850,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
accuracyResult.getActualClasses(),
equalTo(
Arrays.asList(
new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly
// 3 out of 5 examples labeled as "cat" were classified correctly
new AccuracyMetric.ActualClass("cat", 5, 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75),
// no examples labeled as "ant" were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
}
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
@ -1876,20 +1877,29 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
mcmResult.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass(
new MulticlassConfusionMatrixMetric.ActualClass(
"ant",
1L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 0L)),
0L),
new ActualClass(
new MulticlassConfusionMatrixMetric.ActualClass(
"cat",
5L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)),
1L),
new ActualClass(
new MulticlassConfusionMatrixMetric.ActualClass(
"dog",
4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)),
0L))));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
}
@ -1912,8 +1922,20 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
mcmResult.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
new MulticlassConfusionMatrixMetric.ActualClass(
"cat",
5L,
Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)),
1L),
new MulticlassConfusionMatrixMetric.ActualClass(
"dog",
4L,
Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)),
0L)
)));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
}

View File

@ -20,36 +20,56 @@ package org.elasticsearch.client.ml;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetricResultTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<EvaluateDataFrameResponse> {
public static EvaluateDataFrameResponse randomResponse() {
List<EvaluationMetric.Result> metrics = new ArrayList<>();
if (randomBoolean()) {
metrics.add(AucRocMetricResultTests.randomResult());
String evaluationName = randomFrom(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME);
List<EvaluationMetric.Result> metrics;
switch (evaluationName) {
case BinarySoftClassification.NAME:
metrics = randomSubsetOf(
Arrays.asList(
AucRocMetricResultTests.randomResult(),
PrecisionMetricResultTests.randomResult(),
RecallMetricResultTests.randomResult(),
ConfusionMatrixMetricResultTests.randomResult()));
break;
case Regression.NAME:
metrics = randomSubsetOf(
Arrays.asList(
MeanSquaredErrorMetricResultTests.randomResult(),
RSquaredMetricResultTests.randomResult()));
break;
case Classification.NAME:
metrics = randomSubsetOf(
Arrays.asList(
AccuracyMetricResultTests.randomResult(),
MulticlassConfusionMatrixMetricResultTests.randomResult()));
break;
default:
throw new AssertionError("Please add missing \"case\" variant to the \"switch\" statement");
}
if (randomBoolean()) {
metrics.add(PrecisionMetricResultTests.randomResult());
}
if (randomBoolean()) {
metrics.add(RecallMetricResultTests.randomResult());
}
if (randomBoolean()) {
metrics.add(ConfusionMatrixMetricResultTests.randomResult());
}
if (randomBoolean()) {
metrics.add(MeanSquaredErrorMetricResultTests.randomResult());
}
return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics);
return new EvaluateDataFrameResponse(evaluationName, metrics);
}
@Override

View File

@ -31,15 +31,14 @@ import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class AccuracyMetricResultTests extends AbstractXContentTestCase<AccuracyMetric.Result> {
public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected AccuracyMetric.Result createTestInstance() {
public static Result randomResult() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
@ -52,8 +51,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Accuracy
}
@Override
protected AccuracyMetric.Result doParseInstance(XContentParser parser) throws IOException {
return AccuracyMetric.Result.fromXContent(parser);
protected Result createTestInstance() {
return randomResult();
}
@Override
protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser);
}
@Override

View File

@ -38,7 +38,10 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
static Classification createRandom() {
List<EvaluationMetric> metrics =
randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom()));
randomSubsetOf(
Arrays.asList(
AccuracyMetricTests.createRandom(),
MulticlassConfusionMatrixMetricTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}

View File

@ -40,8 +40,7 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected Result createTestInstance() {
public static Result randomResult() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
@ -60,6 +59,11 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
}
@Override
protected Result createTestInstance() {
return randomResult();
}
@Override
protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser);

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
@ -27,11 +26,11 @@ import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.client.ml.AucRocMetricAucRocPointTests.randomPoint;
import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricAucRocPointTests.randomPoint;
public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetric.Result> {
static AucRocMetric.Result randomResult() {
public static AucRocMetric.Result randomResult() {
return new AucRocMetric.Result(
randomDouble(),
Stream

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
@ -27,11 +26,11 @@ import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.client.ml.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix;
import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix;
public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase<ConfusionMatrixMetric.Result> {
static ConfusionMatrixMetric.Result randomResult() {
public static ConfusionMatrixMetric.Result randomResult() {
return new ConfusionMatrixMetric.Result(
Stream
.generate(() -> randomConfusionMatrix())

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
@ -29,7 +28,7 @@ import java.util.stream.Stream;
public class PrecisionMetricResultTests extends AbstractXContentTestCase<PrecisionMetric.Result> {
static PrecisionMetric.Result randomResult() {
public static PrecisionMetric.Result randomResult() {
return new PrecisionMetric.Result(
Stream
.generate(() -> randomDouble())

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml;
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
@ -29,7 +28,7 @@ import java.util.stream.Stream;
public class RecallMetricResultTests extends AbstractXContentTestCase<RecallMetric.Result> {
static RecallMetric.Result randomResult() {
public static RecallMetric.Result randomResult() {
return new RecallMetric.Result(
Stream
.generate(() -> randomDouble())

View File

@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import java.util.Collections;
import java.util.List;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public final class MockAggregations {
public static Terms mockTerms(String name) {
return mockTerms(name, Collections.emptyList(), 0);
}
public static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
Terms agg = mock(Terms.class);
when(agg.getName()).thenReturn(name);
doReturn(buckets).when(agg).getBuckets();
when(agg.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
return agg;
}
public static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
Terms.Bucket bucket = mock(Terms.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
public static Filters mockFilters(String name) {
return mockFilters(name, Collections.emptyList());
}
public static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
Filters agg = mock(Filters.class);
when(agg.getName()).thenReturn(name);
doReturn(buckets).when(agg).getBuckets();
return agg;
}
public static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
public static Filters.Bucket mockFiltersBucket(String key, long docCount) {
Filters.Bucket bucket = mock(Filters.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getDocCount()).thenReturn(docCount);
return bucket;
}
public static Filter mockFilter(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
public static NumericMetricsAggregation.SingleValue mockSingleValue(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
public static Cardinality mockCardinality(String name, long value) {
Cardinality agg = mock(Cardinality.class);
when(agg.getName()).thenReturn(name);
when(agg.getValue()).thenReturn(value);
return agg;
}
public static ExtendedStats mockExtendedStats(String name, double variance, long count) {
ExtendedStats agg = mock(ExtendedStats.class);
when(agg.getName()).thenReturn(name);
when(agg.getVariance()).thenReturn(variance);
when(agg.getCount()).thenReturn(count);
return agg;
}
}

View File

@ -8,17 +8,15 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
@ -48,9 +46,9 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"),
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockTerms("classification_classes"),
mockSingleValue("classification_overall_accuracy", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Accuracy accuracy = new Accuracy();
@ -62,16 +60,16 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess_GivenMissingAgg() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockTerms("classification_classes"),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockSingleValue("classification_overall_accuracy", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
@ -81,32 +79,19 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess_GivenAggOfWrongType() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"),
createTermsAgg("classification_overall_accuracy")
mockTerms("classification_classes"),
mockTerms("classification_overall_accuracy")
));
Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("classification_classes", 1.0),
createSingleMetricAgg("classification_overall_accuracy", 0.8123)
mockSingleValue("classification_classes", 1.0),
mockSingleValue("classification_overall_accuracy", 0.8123)
));
Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
}
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
private static Terms createTermsAgg(String name) {
Terms agg = mock(Terms.class);
when(agg.getName()).thenReturn(name);
return agg;
}
}

View File

@ -52,7 +52,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
public static Classification createRandom() {
List<ClassificationMetric> metrics =
randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom()));
randomSubsetOf(
Arrays.asList(
AccuracyTests.createRandom(),
MulticlassConfusionMatrixTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}

View File

@ -10,9 +10,6 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
@ -21,15 +18,17 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
@ -77,7 +76,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
assertThat(aggs, is(not(empty())));
assertThat(confusionMatrix.getResult(), equalTo(Optional.empty()));
assertThat(confusionMatrix.getResult(), isEmpty());
}
public void testEvaluate() {
@ -163,46 +162,4 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
assertThat(result.getOtherActualClassCount(), equalTo(3L));
}
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
Terms aggregation = mock(Terms.class);
when(aggregation.getName()).thenReturn(name);
doReturn(buckets).when(aggregation).getBuckets();
when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
return aggregation;
}
private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
Terms.Bucket bucket = mock(Terms.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
private static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
Filters aggregation = mock(Filters.class);
when(aggregation.getName()).thenReturn(name);
doReturn(buckets).when(aggregation).getBuckets();
return aggregation;
}
private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
private static Filters.Bucket mockFiltersBucket(String key, long docCount) {
Filters.Bucket bucket = mock(Filters.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getDocCount()).thenReturn(docCount);
return bucket;
}
private static Cardinality mockCardinality(String name, long value) {
Cardinality aggregation = mock(Cardinality.class);
when(aggregation.getName()).thenReturn(name);
when(aggregation.getValue()).thenReturn(value);
return aggregation;
}
}

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -17,9 +16,8 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
@ -44,8 +42,8 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("regression_mean_squared_error", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockSingleValue("regression_mean_squared_error", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
MeanSquaredError mse = new MeanSquaredError();
@ -58,7 +56,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
public void testEvaluate_GivenMissingAggs() {
Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
MeanSquaredError mse = new MeanSquaredError();
@ -67,11 +65,4 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
EvaluationMetricResult result = mse.getResult().get();
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
}

View File

@ -9,8 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +16,9 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockExtendedStats;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
@ -45,10 +43,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("residual_sum_of_squares", 10_111),
createExtendedStatsAgg("extended_stats_actual", 155.23, 1000),
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockSingleValue("residual_sum_of_squares", 10_111),
mockExtendedStats("extended_stats_actual", 155.23, 1000),
mockExtendedStats("some_other_extended_stats",99.1, 10_000),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
RSquared rSquared = new RSquared();
@ -61,10 +59,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluateWithZeroCount() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("residual_sum_of_squares", 0),
createExtendedStatsAgg("extended_stats_actual", 0.0, 0),
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockSingleValue("residual_sum_of_squares", 0),
mockExtendedStats("extended_stats_actual", 0.0, 0),
mockExtendedStats("some_other_extended_stats",99.1, 10_000),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
RSquared rSquared = new RSquared();
@ -76,10 +74,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluateWithSingleCountZeroVariance() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("residual_sum_of_squares", 1),
createExtendedStatsAgg("extended_stats_actual", 0.0, 1),
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockSingleValue("residual_sum_of_squares", 1),
mockExtendedStats("extended_stats_actual", 0.0, 1),
mockExtendedStats("some_other_extended_stats",99.1, 10_000),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
RSquared rSquared = new RSquared();
@ -91,7 +89,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate_GivenMissingAggs() {
Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
RSquared rSquared = new RSquared();
@ -103,8 +101,8 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate_GivenMissingExtendedStatsAgg() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
createSingleMetricAgg("residual_sum_of_squares", 0.2377)
mockSingleValue("some_other_single_metric_agg", 0.2377),
mockSingleValue("residual_sum_of_squares", 0.2377)
));
RSquared rSquared = new RSquared();
@ -116,8 +114,8 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
createExtendedStatsAgg("extended_stats_actual",100, 50)
mockSingleValue("some_other_single_metric_agg", 0.2377),
mockExtendedStats("extended_stats_actual",100, 50)
));
RSquared rSquared = new RSquared();
@ -126,19 +124,4 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
EvaluationMetricResult result = rSquared.getResult().get();
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
private static ExtendedStats createExtendedStatsAgg(String name, double variance, long count) {
ExtendedStats agg = mock(ExtendedStats.class);
when(agg.getName()).thenReturn(name);
when(agg.getVariance()).thenReturn(variance);
when(agg.getCount()).thenReturn(count);
return agg;
}
}

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +17,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionMatrix> {
@ -50,14 +48,14 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("confusion_matrix_at_0.25_TP", 1L),
createFilterAgg("confusion_matrix_at_0.25_FP", 2L),
createFilterAgg("confusion_matrix_at_0.25_TN", 3L),
createFilterAgg("confusion_matrix_at_0.25_FN", 4L),
createFilterAgg("confusion_matrix_at_0.5_TP", 5L),
createFilterAgg("confusion_matrix_at_0.5_FP", 6L),
createFilterAgg("confusion_matrix_at_0.5_TN", 7L),
createFilterAgg("confusion_matrix_at_0.5_FN", 8L)
mockFilter("confusion_matrix_at_0.25_TP", 1L),
mockFilter("confusion_matrix_at_0.25_FP", 2L),
mockFilter("confusion_matrix_at_0.25_TN", 3L),
mockFilter("confusion_matrix_at_0.25_FN", 4L),
mockFilter("confusion_matrix_at_0.5_TP", 5L),
mockFilter("confusion_matrix_at_0.5_FP", 6L),
mockFilter("confusion_matrix_at_0.5_TN", 7L),
mockFilter("confusion_matrix_at_0.5_FN", 8L)
));
ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5));
@ -66,11 +64,4 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}";
assertThat(Strings.toString(result), equalTo(expected));
}
private static Filter createFilterAgg(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
}

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +17,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
@ -50,12 +48,12 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("precision_at_0.25_TP", 1L),
createFilterAgg("precision_at_0.25_FP", 4L),
createFilterAgg("precision_at_0.5_TP", 3L),
createFilterAgg("precision_at_0.5_FP", 1L),
createFilterAgg("precision_at_0.75_TP", 5L),
createFilterAgg("precision_at_0.75_FP", 0L)
mockFilter("precision_at_0.25_TP", 1L),
mockFilter("precision_at_0.25_FP", 4L),
mockFilter("precision_at_0.5_TP", 3L),
mockFilter("precision_at_0.5_FP", 1L),
mockFilter("precision_at_0.75_TP", 5L),
mockFilter("precision_at_0.75_FP", 0L)
));
Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75));
@ -67,8 +65,8 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
public void testEvaluate_GivenZeroTpAndFp() {
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("precision_at_1.0_TP", 0L),
createFilterAgg("precision_at_1.0_FP", 0L)
mockFilter("precision_at_1.0_TP", 0L),
mockFilter("precision_at_1.0_FP", 0L)
));
Precision precision = new Precision(Arrays.asList(1.0));
@ -77,11 +75,4 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
String expected = "{\"1.0\":0.0}";
assertThat(Strings.toString(result), equalTo(expected));
}
private static Filter createFilterAgg(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
}

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +17,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RecallTests extends AbstractSerializingTestCase<Recall> {
@ -50,12 +48,12 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("recall_at_0.25_TP", 1L),
createFilterAgg("recall_at_0.25_FN", 4L),
createFilterAgg("recall_at_0.5_TP", 3L),
createFilterAgg("recall_at_0.5_FN", 1L),
createFilterAgg("recall_at_0.75_TP", 5L),
createFilterAgg("recall_at_0.75_FN", 0L)
mockFilter("recall_at_0.25_TP", 1L),
mockFilter("recall_at_0.25_FN", 4L),
mockFilter("recall_at_0.5_TP", 3L),
mockFilter("recall_at_0.5_FN", 1L),
mockFilter("recall_at_0.75_TP", 5L),
mockFilter("recall_at_0.75_FN", 0L)
));
Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75));
@ -67,8 +65,8 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
public void testEvaluate_GivenZeroTpAndFp() {
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("recall_at_1.0_TP", 0L),
createFilterAgg("recall_at_1.0_FN", 0L)
mockFilter("recall_at_1.0_TP", 0L),
mockFilter("recall_at_1.0_FN", 0L)
));
Recall recall = new Recall(Arrays.asList(1.0));
@ -77,11 +75,4 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
String expected = "{\"1.0\":0.0}";
assertThat(Strings.toString(result), equalTo(expected));
}
private static Filter createFilterAgg(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
}