A few cleanups in evaluation tests (#49791) (#49794)

This commit is contained in:
Przemysław Witek 2019-12-03 15:48:39 +01:00 committed by GitHub
parent fbb92f527a
commit a3f88595d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 276 additions and 241 deletions

View File

@ -33,9 +33,6 @@ import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.GetIndexRequest; import org.elasticsearch.client.indices.GetIndexRequest;
import org.elasticsearch.client.ml.CloseJobRequest; import org.elasticsearch.client.ml.CloseJobRequest;
import org.elasticsearch.client.ml.CloseJobResponse; import org.elasticsearch.client.ml.CloseJobResponse;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarEventRequest;
import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest;
import org.elasticsearch.client.ml.DeleteCalendarRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest;
@ -48,8 +45,11 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteJobResponse;
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureRequest;
import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FindFileStructureResponse;
import org.elasticsearch.client.ml.FlushJobRequest; import org.elasticsearch.client.ml.FlushJobRequest;
@ -135,8 +135,6 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
@ -1852,9 +1850,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
accuracyResult.getActualClasses(), accuracyResult.getActualClasses(),
equalTo( equalTo(
Arrays.asList( Arrays.asList(
new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly // 3 out of 5 examples labeled as "cat" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly new AccuracyMetric.ActualClass("cat", 5, 0.6),
new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly // 3 out of 4 examples labeled as "dog" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75),
// no examples labeled as "ant" were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
} }
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead { // No size provided for MulticlassConfusionMatrixMetric, default used instead
@ -1876,20 +1877,29 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
mcmResult.getConfusionMatrix(), mcmResult.getConfusionMatrix(),
equalTo( equalTo(
Arrays.asList( Arrays.asList(
new ActualClass( new MulticlassConfusionMatrixMetric.ActualClass(
"ant", "ant",
1L, 1L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 0L)),
0L), 0L),
new ActualClass( new MulticlassConfusionMatrixMetric.ActualClass(
"cat", "cat",
5L, 5L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)),
1L), 1L),
new ActualClass( new MulticlassConfusionMatrixMetric.ActualClass(
"dog", "dog",
4L, 4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)),
0L)))); 0L))));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L)); assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
} }
@ -1912,8 +1922,20 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
mcmResult.getConfusionMatrix(), mcmResult.getConfusionMatrix(),
equalTo( equalTo(
Arrays.asList( Arrays.asList(
new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L), new MulticlassConfusionMatrixMetric.ActualClass(
new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L) "cat",
5L,
Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)),
1L),
new MulticlassConfusionMatrixMetric.ActualClass(
"dog",
4L,
Arrays.asList(
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)),
0L)
))); )));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L)); assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
} }

View File

@ -20,36 +20,56 @@ package org.elasticsearch.client.ml;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetricResultTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.function.Predicate; import java.util.function.Predicate;
public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<EvaluateDataFrameResponse> { public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<EvaluateDataFrameResponse> {
public static EvaluateDataFrameResponse randomResponse() { public static EvaluateDataFrameResponse randomResponse() {
List<EvaluationMetric.Result> metrics = new ArrayList<>(); String evaluationName = randomFrom(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME);
if (randomBoolean()) { List<EvaluationMetric.Result> metrics;
metrics.add(AucRocMetricResultTests.randomResult()); switch (evaluationName) {
case BinarySoftClassification.NAME:
metrics = randomSubsetOf(
Arrays.asList(
AucRocMetricResultTests.randomResult(),
PrecisionMetricResultTests.randomResult(),
RecallMetricResultTests.randomResult(),
ConfusionMatrixMetricResultTests.randomResult()));
break;
case Regression.NAME:
metrics = randomSubsetOf(
Arrays.asList(
MeanSquaredErrorMetricResultTests.randomResult(),
RSquaredMetricResultTests.randomResult()));
break;
case Classification.NAME:
metrics = randomSubsetOf(
Arrays.asList(
AccuracyMetricResultTests.randomResult(),
MulticlassConfusionMatrixMetricResultTests.randomResult()));
break;
default:
throw new AssertionError("Please add missing \"case\" variant to the \"switch\" statement");
} }
if (randomBoolean()) { return new EvaluateDataFrameResponse(evaluationName, metrics);
metrics.add(PrecisionMetricResultTests.randomResult());
}
if (randomBoolean()) {
metrics.add(RecallMetricResultTests.randomResult());
}
if (randomBoolean()) {
metrics.add(ConfusionMatrixMetricResultTests.randomResult());
}
if (randomBoolean()) {
metrics.add(MeanSquaredErrorMetricResultTests.randomResult());
}
return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics);
} }
@Override @Override

View File

@ -31,15 +31,14 @@ import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
public class AccuracyMetricResultTests extends AbstractXContentTestCase<AccuracyMetric.Result> { public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result> {
@Override @Override
protected NamedXContentRegistry xContentRegistry() { protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
} }
@Override public static Result randomResult() {
protected AccuracyMetric.Result createTestInstance() {
int numClasses = randomIntBetween(2, 100); int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses); List<ActualClass> actualClasses = new ArrayList<>(numClasses);
@ -52,8 +51,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Accuracy
} }
@Override @Override
protected AccuracyMetric.Result doParseInstance(XContentParser parser) throws IOException { protected Result createTestInstance() {
return AccuracyMetric.Result.fromXContent(parser); return randomResult();
}
@Override
protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser);
} }
@Override @Override

View File

@ -38,7 +38,10 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
static Classification createRandom() { static Classification createRandom() {
List<EvaluationMetric> metrics = List<EvaluationMetric> metrics =
randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom())); randomSubsetOf(
Arrays.asList(
AccuracyMetricTests.createRandom(),
MulticlassConfusionMatrixMetricTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
} }

View File

@ -40,8 +40,7 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
} }
@Override public static Result randomResult() {
protected Result createTestInstance() {
int numClasses = randomIntBetween(2, 100); int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses); List<ActualClass> actualClasses = new ArrayList<>(numClasses);
@ -60,6 +59,11 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null); return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
} }
@Override
protected Result createTestInstance() {
return randomResult();
}
@Override @Override
protected Result doParseInstance(XContentParser parser) throws IOException { protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser); return Result.fromXContent(parser);

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml; package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml; package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
@ -27,11 +26,11 @@ import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.elasticsearch.client.ml.AucRocMetricAucRocPointTests.randomPoint; import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricAucRocPointTests.randomPoint;
public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetric.Result> { public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetric.Result> {
static AucRocMetric.Result randomResult() { public static AucRocMetric.Result randomResult() {
return new AucRocMetric.Result( return new AucRocMetric.Result(
randomDouble(), randomDouble(),
Stream Stream

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml; package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml; package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
@ -27,11 +26,11 @@ import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import static org.elasticsearch.client.ml.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix; import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix;
public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase<ConfusionMatrixMetric.Result> { public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase<ConfusionMatrixMetric.Result> {
static ConfusionMatrixMetric.Result randomResult() { public static ConfusionMatrixMetric.Result randomResult() {
return new ConfusionMatrixMetric.Result( return new ConfusionMatrixMetric.Result(
Stream Stream
.generate(() -> randomConfusionMatrix()) .generate(() -> randomConfusionMatrix())

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml; package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
@ -29,7 +28,7 @@ import java.util.stream.Stream;
public class PrecisionMetricResultTests extends AbstractXContentTestCase<PrecisionMetric.Result> { public class PrecisionMetricResultTests extends AbstractXContentTestCase<PrecisionMetric.Result> {
static PrecisionMetric.Result randomResult() { public static PrecisionMetric.Result randomResult() {
return new PrecisionMetric.Result( return new PrecisionMetric.Result(
Stream Stream
.generate(() -> randomDouble()) .generate(() -> randomDouble())

View File

@ -16,9 +16,8 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml; package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
@ -29,7 +28,7 @@ import java.util.stream.Stream;
public class RecallMetricResultTests extends AbstractXContentTestCase<RecallMetric.Result> { public class RecallMetricResultTests extends AbstractXContentTestCase<RecallMetric.Result> {
static RecallMetric.Result randomResult() { public static RecallMetric.Result randomResult() {
return new RecallMetric.Result( return new RecallMetric.Result(
Stream Stream
.generate(() -> randomDouble()) .generate(() -> randomDouble())

View File

@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import java.util.Collections;
import java.util.List;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public final class MockAggregations {
public static Terms mockTerms(String name) {
return mockTerms(name, Collections.emptyList(), 0);
}
public static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
Terms agg = mock(Terms.class);
when(agg.getName()).thenReturn(name);
doReturn(buckets).when(agg).getBuckets();
when(agg.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
return agg;
}
public static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
Terms.Bucket bucket = mock(Terms.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
public static Filters mockFilters(String name) {
return mockFilters(name, Collections.emptyList());
}
public static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
Filters agg = mock(Filters.class);
when(agg.getName()).thenReturn(name);
doReturn(buckets).when(agg).getBuckets();
return agg;
}
public static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
public static Filters.Bucket mockFiltersBucket(String key, long docCount) {
Filters.Bucket bucket = mock(Filters.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getDocCount()).thenReturn(docCount);
return bucket;
}
public static Filter mockFilter(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
public static NumericMetricsAggregation.SingleValue mockSingleValue(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
public static Cardinality mockCardinality(String name, long value) {
Cardinality agg = mock(Cardinality.class);
when(agg.getName()).thenReturn(name);
when(agg.getValue()).thenReturn(value);
return agg;
}
public static ExtendedStats mockExtendedStats(String name, double variance, long count) {
ExtendedStats agg = mock(ExtendedStats.class);
when(agg.getName()).thenReturn(name);
when(agg.getVariance()).thenReturn(variance);
when(agg.getCount()).thenReturn(count);
return agg;
}
}

View File

@ -8,17 +8,15 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> { public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
@ -48,9 +46,9 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess() { public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"), mockTerms("classification_classes"),
createSingleMetricAgg("classification_overall_accuracy", 0.8123), mockSingleValue("classification_overall_accuracy", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
@ -62,16 +60,16 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess_GivenMissingAgg() { public void testProcess_GivenMissingAgg() {
{ {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"), mockTerms("classification_classes"),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
} }
{ {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("classification_overall_accuracy", 0.8123), mockSingleValue("classification_overall_accuracy", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
@ -81,32 +79,19 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess_GivenAggOfWrongType() { public void testProcess_GivenAggOfWrongType() {
{ {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"), mockTerms("classification_classes"),
createTermsAgg("classification_overall_accuracy") mockTerms("classification_overall_accuracy")
)); ));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
} }
{ {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("classification_classes", 1.0), mockSingleValue("classification_classes", 1.0),
createSingleMetricAgg("classification_overall_accuracy", 0.8123) mockSingleValue("classification_overall_accuracy", 0.8123)
)); ));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
} }
} }
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
private static Terms createTermsAgg(String name) {
Terms agg = mock(Terms.class);
when(agg.getName()).thenReturn(name);
return agg;
}
} }

View File

@ -52,7 +52,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
public static Classification createRandom() { public static Classification createRandom() {
List<ClassificationMetric> metrics = List<ClassificationMetric> metrics =
randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom())); randomSubsetOf(
Arrays.asList(
AccuracyTests.createRandom(),
MulticlassConfusionMatrixTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
} }

View File

@ -10,9 +10,6 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
@ -21,15 +18,17 @@ import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> { public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
@ -77,7 +76,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred"); List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
assertThat(aggs, is(not(empty()))); assertThat(aggs, is(not(empty())));
assertThat(confusionMatrix.getResult(), equalTo(Optional.empty())); assertThat(confusionMatrix.getResult(), isEmpty());
} }
public void testEvaluate() { public void testEvaluate() {
@ -163,46 +162,4 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
assertThat(result.getOtherActualClassCount(), equalTo(3L)); assertThat(result.getOtherActualClassCount(), equalTo(3L));
} }
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
Terms aggregation = mock(Terms.class);
when(aggregation.getName()).thenReturn(name);
doReturn(buckets).when(aggregation).getBuckets();
when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
return aggregation;
}
private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
Terms.Bucket bucket = mock(Terms.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
private static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
Filters aggregation = mock(Filters.class);
when(aggregation.getName()).thenReturn(name);
doReturn(buckets).when(aggregation).getBuckets();
return aggregation;
}
private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
private static Filters.Bucket mockFiltersBucket(String key, long docCount) {
Filters.Bucket bucket = mock(Filters.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getDocCount()).thenReturn(docCount);
return bucket;
}
private static Cardinality mockCardinality(String name, long value) {
Cardinality aggregation = mock(Cardinality.class);
when(aggregation.getName()).thenReturn(name);
when(aggregation.getValue()).thenReturn(value);
return aggregation;
}
} }

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -17,9 +16,8 @@ import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> { public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
@ -44,8 +42,8 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("regression_mean_squared_error", 0.8123), mockSingleValue("regression_mean_squared_error", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
MeanSquaredError mse = new MeanSquaredError(); MeanSquaredError mse = new MeanSquaredError();
@ -58,7 +56,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
public void testEvaluate_GivenMissingAggs() { public void testEvaluate_GivenMissingAggs() {
Aggregations aggs = new Aggregations(Collections.singletonList( Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
MeanSquaredError mse = new MeanSquaredError(); MeanSquaredError mse = new MeanSquaredError();
@ -67,11 +65,4 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
EvaluationMetricResult result = mse.getResult().get(); EvaluationMetricResult result = mse.getResult().get();
assertThat(result, equalTo(new MeanSquaredError.Result(0.0))); assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
} }
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
} }

View File

@ -9,8 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +16,9 @@ import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockExtendedStats;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RSquaredTests extends AbstractSerializingTestCase<RSquared> { public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
@ -45,10 +43,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("residual_sum_of_squares", 10_111), mockSingleValue("residual_sum_of_squares", 10_111),
createExtendedStatsAgg("extended_stats_actual", 155.23, 1000), mockExtendedStats("extended_stats_actual", 155.23, 1000),
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000), mockExtendedStats("some_other_extended_stats",99.1, 10_000),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
RSquared rSquared = new RSquared(); RSquared rSquared = new RSquared();
@ -61,10 +59,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluateWithZeroCount() { public void testEvaluateWithZeroCount() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("residual_sum_of_squares", 0), mockSingleValue("residual_sum_of_squares", 0),
createExtendedStatsAgg("extended_stats_actual", 0.0, 0), mockExtendedStats("extended_stats_actual", 0.0, 0),
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000), mockExtendedStats("some_other_extended_stats",99.1, 10_000),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
RSquared rSquared = new RSquared(); RSquared rSquared = new RSquared();
@ -76,10 +74,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluateWithSingleCountZeroVariance() { public void testEvaluateWithSingleCountZeroVariance() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("residual_sum_of_squares", 1), mockSingleValue("residual_sum_of_squares", 1),
createExtendedStatsAgg("extended_stats_actual", 0.0, 1), mockExtendedStats("extended_stats_actual", 0.0, 1),
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000), mockExtendedStats("some_other_extended_stats",99.1, 10_000),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
RSquared rSquared = new RSquared(); RSquared rSquared = new RSquared();
@ -91,7 +89,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate_GivenMissingAggs() { public void testEvaluate_GivenMissingAggs() {
Aggregations aggs = new Aggregations(Collections.singletonList( Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
RSquared rSquared = new RSquared(); RSquared rSquared = new RSquared();
@ -103,8 +101,8 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate_GivenMissingExtendedStatsAgg() { public void testEvaluate_GivenMissingExtendedStatsAgg() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377), mockSingleValue("some_other_single_metric_agg", 0.2377),
createSingleMetricAgg("residual_sum_of_squares", 0.2377) mockSingleValue("residual_sum_of_squares", 0.2377)
)); ));
RSquared rSquared = new RSquared(); RSquared rSquared = new RSquared();
@ -116,8 +114,8 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() { public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377), mockSingleValue("some_other_single_metric_agg", 0.2377),
createExtendedStatsAgg("extended_stats_actual",100, 50) mockExtendedStats("extended_stats_actual",100, 50)
)); ));
RSquared rSquared = new RSquared(); RSquared rSquared = new RSquared();
@ -126,19 +124,4 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
EvaluationMetricResult result = rSquared.getResult().get(); EvaluationMetricResult result = rSquared.getResult().get();
assertThat(result, equalTo(new RSquared.Result(0.0))); assertThat(result, equalTo(new RSquared.Result(0.0)));
} }
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
private static ExtendedStats createExtendedStatsAgg(String name, double variance, long count) {
ExtendedStats agg = mock(ExtendedStats.class);
when(agg.getName()).thenReturn(name);
when(agg.getVariance()).thenReturn(variance);
when(agg.getCount()).thenReturn(count);
return agg;
}
} }

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +17,8 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionMatrix> { public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionMatrix> {
@ -50,14 +48,14 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("confusion_matrix_at_0.25_TP", 1L), mockFilter("confusion_matrix_at_0.25_TP", 1L),
createFilterAgg("confusion_matrix_at_0.25_FP", 2L), mockFilter("confusion_matrix_at_0.25_FP", 2L),
createFilterAgg("confusion_matrix_at_0.25_TN", 3L), mockFilter("confusion_matrix_at_0.25_TN", 3L),
createFilterAgg("confusion_matrix_at_0.25_FN", 4L), mockFilter("confusion_matrix_at_0.25_FN", 4L),
createFilterAgg("confusion_matrix_at_0.5_TP", 5L), mockFilter("confusion_matrix_at_0.5_TP", 5L),
createFilterAgg("confusion_matrix_at_0.5_FP", 6L), mockFilter("confusion_matrix_at_0.5_FP", 6L),
createFilterAgg("confusion_matrix_at_0.5_TN", 7L), mockFilter("confusion_matrix_at_0.5_TN", 7L),
createFilterAgg("confusion_matrix_at_0.5_FN", 8L) mockFilter("confusion_matrix_at_0.5_FN", 8L)
)); ));
ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5)); ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5));
@ -66,11 +64,4 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}"; String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}";
assertThat(Strings.toString(result), equalTo(expected)); assertThat(Strings.toString(result), equalTo(expected));
} }
private static Filter createFilterAgg(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
} }

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +17,8 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class PrecisionTests extends AbstractSerializingTestCase<Precision> { public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
@ -50,12 +48,12 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("precision_at_0.25_TP", 1L), mockFilter("precision_at_0.25_TP", 1L),
createFilterAgg("precision_at_0.25_FP", 4L), mockFilter("precision_at_0.25_FP", 4L),
createFilterAgg("precision_at_0.5_TP", 3L), mockFilter("precision_at_0.5_TP", 3L),
createFilterAgg("precision_at_0.5_FP", 1L), mockFilter("precision_at_0.5_FP", 1L),
createFilterAgg("precision_at_0.75_TP", 5L), mockFilter("precision_at_0.75_TP", 5L),
createFilterAgg("precision_at_0.75_FP", 0L) mockFilter("precision_at_0.75_FP", 0L)
)); ));
Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75)); Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75));
@ -67,8 +65,8 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
public void testEvaluate_GivenZeroTpAndFp() { public void testEvaluate_GivenZeroTpAndFp() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("precision_at_1.0_TP", 0L), mockFilter("precision_at_1.0_TP", 0L),
createFilterAgg("precision_at_1.0_FP", 0L) mockFilter("precision_at_1.0_FP", 0L)
)); ));
Precision precision = new Precision(Arrays.asList(1.0)); Precision precision = new Precision(Arrays.asList(1.0));
@ -77,11 +75,4 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
String expected = "{\"1.0\":0.0}"; String expected = "{\"1.0\":0.0}";
assertThat(Strings.toString(result), equalTo(expected)); assertThat(Strings.toString(result), equalTo(expected));
} }
private static Filter createFilterAgg(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
} }

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -18,9 +17,8 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RecallTests extends AbstractSerializingTestCase<Recall> { public class RecallTests extends AbstractSerializingTestCase<Recall> {
@ -50,12 +48,12 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("recall_at_0.25_TP", 1L), mockFilter("recall_at_0.25_TP", 1L),
createFilterAgg("recall_at_0.25_FN", 4L), mockFilter("recall_at_0.25_FN", 4L),
createFilterAgg("recall_at_0.5_TP", 3L), mockFilter("recall_at_0.5_TP", 3L),
createFilterAgg("recall_at_0.5_FN", 1L), mockFilter("recall_at_0.5_FN", 1L),
createFilterAgg("recall_at_0.75_TP", 5L), mockFilter("recall_at_0.75_TP", 5L),
createFilterAgg("recall_at_0.75_FN", 0L) mockFilter("recall_at_0.75_FN", 0L)
)); ));
Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75)); Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75));
@ -67,8 +65,8 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
public void testEvaluate_GivenZeroTpAndFp() { public void testEvaluate_GivenZeroTpAndFp() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("recall_at_1.0_TP", 0L), mockFilter("recall_at_1.0_TP", 0L),
createFilterAgg("recall_at_1.0_FN", 0L) mockFilter("recall_at_1.0_FN", 0L)
)); ));
Recall recall = new Recall(Arrays.asList(1.0)); Recall recall = new Recall(Arrays.asList(1.0));
@ -77,11 +75,4 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
String expected = "{\"1.0\":0.0}"; String expected = "{\"1.0\":0.0}";
assertThat(Strings.toString(result), equalTo(expected)); assertThat(Strings.toString(result), equalTo(expected));
} }
private static Filter createFilterAgg(String name, long docCount) {
Filter agg = mock(Filter.class);
when(agg.getName()).thenReturn(name);
when(agg.getDocCount()).thenReturn(docCount);
return agg;
}
} }