diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 92d570ba0a3..56319e35153 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -33,9 +33,6 @@ import org.elasticsearch.client.indices.CreateIndexRequest; import org.elasticsearch.client.indices.GetIndexRequest; import org.elasticsearch.client.ml.CloseJobRequest; import org.elasticsearch.client.ml.CloseJobResponse; -import org.elasticsearch.client.ml.DeleteTrainedModelRequest; -import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest; -import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; @@ -48,8 +45,11 @@ import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.DeleteTrainedModelRequest; import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.EvaluateDataFrameResponse; +import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -135,8 +135,6 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; @@ -1852,9 +1850,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { accuracyResult.getActualClasses(), equalTo( Arrays.asList( - new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly - new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly - new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly + // 3 out of 5 examples labeled as "cat" were classified correctly + new AccuracyMetric.ActualClass("cat", 5, 0.6), + // 3 out of 4 examples labeled as "dog" were classified correctly + new AccuracyMetric.ActualClass("dog", 4, 0.75), + // no examples labeled as "ant" were classified correctly + new AccuracyMetric.ActualClass("ant", 1, 0.0)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly } { // No size provided for MulticlassConfusionMatrixMetric, default used instead @@ -1876,20 +1877,29 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { mcmResult.getConfusionMatrix(), equalTo( Arrays.asList( - new ActualClass( + new MulticlassConfusionMatrixMetric.ActualClass( "ant", 1L, - Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), + Arrays.asList( + new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L), + new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L), + new MulticlassConfusionMatrixMetric.PredictedClass("dog", 0L)), 0L), - new ActualClass( + new MulticlassConfusionMatrixMetric.ActualClass( "cat", 5L, - Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), + Arrays.asList( + new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L), + new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L), + new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)), 1L), - new ActualClass( + new MulticlassConfusionMatrixMetric.ActualClass( "dog", 4L, - Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), + Arrays.asList( + new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L), + new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L), + new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)), 0L)))); assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L)); } @@ -1912,8 +1922,20 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { mcmResult.getConfusionMatrix(), equalTo( Arrays.asList( - new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L), - new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L) + new MulticlassConfusionMatrixMetric.ActualClass( + "cat", + 5L, + Arrays.asList( + new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L), + new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)), + 1L), + new MulticlassConfusionMatrixMetric.ActualClass( + "dog", + 4L, + Arrays.asList( + new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L), + new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)), + 0L) ))); assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L)); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java index 70740a3268f..f6b7459b104 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -20,36 +20,56 @@ package org.elasticsearch.client.ml; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetricResultTests; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Predicate; public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase { public static EvaluateDataFrameResponse randomResponse() { - List metrics = new ArrayList<>(); - if (randomBoolean()) { - metrics.add(AucRocMetricResultTests.randomResult()); + String evaluationName = randomFrom(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME); + List metrics; + switch (evaluationName) { + case BinarySoftClassification.NAME: + metrics = randomSubsetOf( + Arrays.asList( + AucRocMetricResultTests.randomResult(), + PrecisionMetricResultTests.randomResult(), + RecallMetricResultTests.randomResult(), + ConfusionMatrixMetricResultTests.randomResult())); + break; + case Regression.NAME: + metrics = randomSubsetOf( + Arrays.asList( + MeanSquaredErrorMetricResultTests.randomResult(), + RSquaredMetricResultTests.randomResult())); + break; + case Classification.NAME: + metrics = randomSubsetOf( + Arrays.asList( + AccuracyMetricResultTests.randomResult(), + MulticlassConfusionMatrixMetricResultTests.randomResult())); + break; + default: + throw new AssertionError("Please add missing \"case\" variant to the \"switch\" statement"); } - if (randomBoolean()) { - metrics.add(PrecisionMetricResultTests.randomResult()); - } - if (randomBoolean()) { - metrics.add(RecallMetricResultTests.randomResult()); - } - if (randomBoolean()) { - metrics.add(ConfusionMatrixMetricResultTests.randomResult()); - } - if (randomBoolean()) { - metrics.add(MeanSquaredErrorMetricResultTests.randomResult()); - } - return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics); + return new EvaluateDataFrameResponse(evaluationName, metrics); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java index 4e6557b4f58..df48ef3123d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java @@ -31,15 +31,14 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; -public class AccuracyMetricResultTests extends AbstractXContentTestCase { +public class AccuracyMetricResultTests extends AbstractXContentTestCase { @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); } - @Override - protected AccuracyMetric.Result createTestInstance() { + public static Result randomResult() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List actualClasses = new ArrayList<>(numClasses); @@ -52,8 +51,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase metrics = - randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom())); + randomSubsetOf( + Arrays.asList( + AccuracyMetricTests.createRandom(), + MulticlassConfusionMatrixMetricTests.createRandom())); return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java index 55b74eb94ea..b08b10f3203 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -40,8 +40,7 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); } - @Override - protected Result createTestInstance() { + public static Result randomResult() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List actualClasses = new ArrayList<>(numClasses); @@ -60,6 +59,11 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null); } + @Override + protected Result createTestInstance() { + return randomResult(); + } + @Override protected Result doParseInstance(XContentParser parser) throws IOException { return Result.fromXContent(parser); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetricAucRocPointTests.java similarity index 92% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetricAucRocPointTests.java index 825adcd2060..93f2b25a734 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetricAucRocPointTests.java @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml; +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetricResultTests.java similarity index 88% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetricResultTests.java index 9ea7689d60f..bd8fc8e790e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetricResultTests.java @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml; +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -27,11 +26,11 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.client.ml.AucRocMetricAucRocPointTests.randomPoint; +import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricAucRocPointTests.randomPoint; public class AucRocMetricResultTests extends AbstractXContentTestCase { - static AucRocMetric.Result randomResult() { + public static AucRocMetric.Result randomResult() { return new AucRocMetric.Result( randomDouble(), Stream diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetricConfusionMatrixTests.java similarity index 92% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetricConfusionMatrixTests.java index b54bcd53fc4..39897112f38 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetricConfusionMatrixTests.java @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml; +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetricResultTests.java similarity index 87% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetricResultTests.java index c4b299a96b5..42819e077d8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetricResultTests.java @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml; +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -27,11 +26,11 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.client.ml.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix; +import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix; public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase { - static ConfusionMatrixMetric.Result randomResult() { + public static ConfusionMatrixMetric.Result randomResult() { return new ConfusionMatrixMetric.Result( Stream .generate(() -> randomConfusionMatrix()) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetricResultTests.java similarity index 91% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetricResultTests.java index 607adacebb8..7ece003ef22 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetricResultTests.java @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml; +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -29,7 +28,7 @@ import java.util.stream.Stream; public class PrecisionMetricResultTests extends AbstractXContentTestCase { - static PrecisionMetric.Result randomResult() { + public static PrecisionMetric.Result randomResult() { return new PrecisionMetric.Result( Stream .generate(() -> randomDouble()) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetricResultTests.java similarity index 91% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetricResultTests.java index 138875007e3..85d9b38075e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetricResultTests.java @@ -16,9 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml; +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; -import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -29,7 +28,7 @@ import java.util.stream.Stream; public class RecallMetricResultTests extends AbstractXContentTestCase { - static RecallMetric.Result randomResult() { + public static RecallMetric.Result randomResult() { return new RecallMetric.Result( Stream .generate(() -> randomDouble()) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MockAggregations.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MockAggregations.java new file mode 100644 index 00000000000..d5919930cb8 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MockAggregations.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.search.aggregations.bucket.filter.Filters; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.search.aggregations.metrics.ExtendedStats; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; + +import java.util.Collections; +import java.util.List; + +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public final class MockAggregations { + + public static Terms mockTerms(String name) { + return mockTerms(name, Collections.emptyList(), 0); + } + + public static Terms mockTerms(String name, List buckets, long sumOfOtherDocCounts) { + Terms agg = mock(Terms.class); + when(agg.getName()).thenReturn(name); + doReturn(buckets).when(agg).getBuckets(); + when(agg.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts); + return agg; + } + + public static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) { + Terms.Bucket bucket = mock(Terms.Bucket.class); + when(bucket.getKeyAsString()).thenReturn(key); + when(bucket.getAggregations()).thenReturn(subAggs); + return bucket; + } + + public static Filters mockFilters(String name) { + return mockFilters(name, Collections.emptyList()); + } + + public static Filters mockFilters(String name, List buckets) { + Filters agg = mock(Filters.class); + when(agg.getName()).thenReturn(name); + doReturn(buckets).when(agg).getBuckets(); + return agg; + } + + public static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) { + Filters.Bucket bucket = mockFiltersBucket(key, docCount); + when(bucket.getAggregations()).thenReturn(subAggs); + return bucket; + } + + public static Filters.Bucket mockFiltersBucket(String key, long docCount) { + Filters.Bucket bucket = mock(Filters.Bucket.class); + when(bucket.getKeyAsString()).thenReturn(key); + when(bucket.getDocCount()).thenReturn(docCount); + return bucket; + } + + public static Filter mockFilter(String name, long docCount) { + Filter agg = mock(Filter.class); + when(agg.getName()).thenReturn(name); + when(agg.getDocCount()).thenReturn(docCount); + return agg; + } + + public static NumericMetricsAggregation.SingleValue mockSingleValue(String name, double value) { + NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class); + when(agg.getName()).thenReturn(name); + when(agg.value()).thenReturn(value); + return agg; + } + + public static Cardinality mockCardinality(String name, long value) { + Cardinality agg = mock(Cardinality.class); + when(agg.getName()).thenReturn(name); + when(agg.getValue()).thenReturn(value); + return agg; + } + + public static ExtendedStats mockExtendedStats(String name, double variance, long count) { + ExtendedStats agg = mock(ExtendedStats.class); + when(agg.getName()).thenReturn(name); + when(agg.getVariance()).thenReturn(variance); + when(agg.getCount()).thenReturn(count); + return agg; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index 1c4caa0c51d..c5e36564c57 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -8,17 +8,15 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.bucket.terms.Terms; -import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class AccuracyTests extends AbstractSerializingTestCase { @@ -48,9 +46,9 @@ public class AccuracyTests extends AbstractSerializingTestCase { public void testProcess() { Aggregations aggs = new Aggregations(Arrays.asList( - createTermsAgg("classification_classes"), - createSingleMetricAgg("classification_overall_accuracy", 0.8123), - createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + mockTerms("classification_classes"), + mockSingleValue("classification_overall_accuracy", 0.8123), + mockSingleValue("some_other_single_metric_agg", 0.2377) )); Accuracy accuracy = new Accuracy(); @@ -62,16 +60,16 @@ public class AccuracyTests extends AbstractSerializingTestCase { public void testProcess_GivenMissingAgg() { { Aggregations aggs = new Aggregations(Arrays.asList( - createTermsAgg("classification_classes"), - createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + mockTerms("classification_classes"), + mockSingleValue("some_other_single_metric_agg", 0.2377) )); Accuracy accuracy = new Accuracy(); expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); } { Aggregations aggs = new Aggregations(Arrays.asList( - createSingleMetricAgg("classification_overall_accuracy", 0.8123), - createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + mockSingleValue("classification_overall_accuracy", 0.8123), + mockSingleValue("some_other_single_metric_agg", 0.2377) )); Accuracy accuracy = new Accuracy(); expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); @@ -81,32 +79,19 @@ public class AccuracyTests extends AbstractSerializingTestCase { public void testProcess_GivenAggOfWrongType() { { Aggregations aggs = new Aggregations(Arrays.asList( - createTermsAgg("classification_classes"), - createTermsAgg("classification_overall_accuracy") + mockTerms("classification_classes"), + mockTerms("classification_overall_accuracy") )); Accuracy accuracy = new Accuracy(); expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); } { Aggregations aggs = new Aggregations(Arrays.asList( - createSingleMetricAgg("classification_classes", 1.0), - createSingleMetricAgg("classification_overall_accuracy", 0.8123) + mockSingleValue("classification_classes", 1.0), + mockSingleValue("classification_overall_accuracy", 0.8123) )); Accuracy accuracy = new Accuracy(); expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); } } - - private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) { - NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class); - when(agg.getName()).thenReturn(name); - when(agg.value()).thenReturn(value); - return agg; - } - - private static Terms createTermsAgg(String name) { - Terms agg = mock(Terms.class); - when(agg.getName()).thenReturn(name); - return agg; - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index bceee8b399e..6deb06cf66d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -52,7 +52,10 @@ public class ClassificationTests extends AbstractSerializingTestCase metrics = - randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom())); + randomSubsetOf( + Arrays.asList( + AccuracyTests.createRandom(), + MulticlassConfusionMatrixTests.createRandom())); return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index 0991093c9ee..bb6c484a545 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -10,9 +10,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.bucket.filter.Filters; -import org.elasticsearch.search.aggregations.bucket.terms.Terms; -import org.elasticsearch.search.aggregations.metrics.Cardinality; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; @@ -21,15 +18,17 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Optional; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase { @@ -77,7 +76,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); List aggs = confusionMatrix.aggs("act", "pred"); assertThat(aggs, is(not(empty()))); - assertThat(confusionMatrix.getResult(), equalTo(Optional.empty())); + assertThat(confusionMatrix.getResult(), isEmpty()); } public void testEvaluate() { @@ -163,46 +162,4 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); assertThat(result.getOtherActualClassCount(), equalTo(3L)); } - - private static Terms mockTerms(String name, List buckets, long sumOfOtherDocCounts) { - Terms aggregation = mock(Terms.class); - when(aggregation.getName()).thenReturn(name); - doReturn(buckets).when(aggregation).getBuckets(); - when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts); - return aggregation; - } - - private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) { - Terms.Bucket bucket = mock(Terms.Bucket.class); - when(bucket.getKeyAsString()).thenReturn(key); - when(bucket.getAggregations()).thenReturn(subAggs); - return bucket; - } - - private static Filters mockFilters(String name, List buckets) { - Filters aggregation = mock(Filters.class); - when(aggregation.getName()).thenReturn(name); - doReturn(buckets).when(aggregation).getBuckets(); - return aggregation; - } - - private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) { - Filters.Bucket bucket = mockFiltersBucket(key, docCount); - when(bucket.getAggregations()).thenReturn(subAggs); - return bucket; - } - - private static Filters.Bucket mockFiltersBucket(String key, long docCount) { - Filters.Bucket bucket = mock(Filters.Bucket.class); - when(bucket.getKeyAsString()).thenReturn(key); - when(bucket.getDocCount()).thenReturn(docCount); - return bucket; - } - - private static Cardinality mockCardinality(String name, long value) { - Cardinality aggregation = mock(Cardinality.class); - when(aggregation.getName()).thenReturn(name); - when(aggregation.getValue()).thenReturn(value); - return aggregation; - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java index 2516b2fea94..5679655c165 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; @@ -17,9 +16,8 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class MeanSquaredErrorTests extends AbstractSerializingTestCase { @@ -44,8 +42,8 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase { @@ -45,10 +43,10 @@ public class RSquaredTests extends AbstractSerializingTestCase { public void testEvaluate() { Aggregations aggs = new Aggregations(Arrays.asList( - createSingleMetricAgg("residual_sum_of_squares", 10_111), - createExtendedStatsAgg("extended_stats_actual", 155.23, 1000), - createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000), - createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + mockSingleValue("residual_sum_of_squares", 10_111), + mockExtendedStats("extended_stats_actual", 155.23, 1000), + mockExtendedStats("some_other_extended_stats",99.1, 10_000), + mockSingleValue("some_other_single_metric_agg", 0.2377) )); RSquared rSquared = new RSquared(); @@ -61,10 +59,10 @@ public class RSquaredTests extends AbstractSerializingTestCase { public void testEvaluateWithZeroCount() { Aggregations aggs = new Aggregations(Arrays.asList( - createSingleMetricAgg("residual_sum_of_squares", 0), - createExtendedStatsAgg("extended_stats_actual", 0.0, 0), - createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000), - createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + mockSingleValue("residual_sum_of_squares", 0), + mockExtendedStats("extended_stats_actual", 0.0, 0), + mockExtendedStats("some_other_extended_stats",99.1, 10_000), + mockSingleValue("some_other_single_metric_agg", 0.2377) )); RSquared rSquared = new RSquared(); @@ -76,10 +74,10 @@ public class RSquaredTests extends AbstractSerializingTestCase { public void testEvaluateWithSingleCountZeroVariance() { Aggregations aggs = new Aggregations(Arrays.asList( - createSingleMetricAgg("residual_sum_of_squares", 1), - createExtendedStatsAgg("extended_stats_actual", 0.0, 1), - createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000), - createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + mockSingleValue("residual_sum_of_squares", 1), + mockExtendedStats("extended_stats_actual", 0.0, 1), + mockExtendedStats("some_other_extended_stats",99.1, 10_000), + mockSingleValue("some_other_single_metric_agg", 0.2377) )); RSquared rSquared = new RSquared(); @@ -91,7 +89,7 @@ public class RSquaredTests extends AbstractSerializingTestCase { public void testEvaluate_GivenMissingAggs() { Aggregations aggs = new Aggregations(Collections.singletonList( - createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + mockSingleValue("some_other_single_metric_agg", 0.2377) )); RSquared rSquared = new RSquared(); @@ -103,8 +101,8 @@ public class RSquaredTests extends AbstractSerializingTestCase { public void testEvaluate_GivenMissingExtendedStatsAgg() { Aggregations aggs = new Aggregations(Arrays.asList( - createSingleMetricAgg("some_other_single_metric_agg", 0.2377), - createSingleMetricAgg("residual_sum_of_squares", 0.2377) + mockSingleValue("some_other_single_metric_agg", 0.2377), + mockSingleValue("residual_sum_of_squares", 0.2377) )); RSquared rSquared = new RSquared(); @@ -116,8 +114,8 @@ public class RSquaredTests extends AbstractSerializingTestCase { public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() { Aggregations aggs = new Aggregations(Arrays.asList( - createSingleMetricAgg("some_other_single_metric_agg", 0.2377), - createExtendedStatsAgg("extended_stats_actual",100, 50) + mockSingleValue("some_other_single_metric_agg", 0.2377), + mockExtendedStats("extended_stats_actual",100, 50) )); RSquared rSquared = new RSquared(); @@ -126,19 +124,4 @@ public class RSquaredTests extends AbstractSerializingTestCase { EvaluationMetricResult result = rSquared.getResult().get(); assertThat(result, equalTo(new RSquared.Result(0.0))); } - - private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) { - NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class); - when(agg.getName()).thenReturn(name); - when(agg.value()).thenReturn(value); - return agg; - } - - private static ExtendedStats createExtendedStatsAgg(String name, double variance, long count) { - ExtendedStats agg = mock(ExtendedStats.class); - when(agg.getName()).thenReturn(name); - when(agg.getVariance()).thenReturn(variance); - when(agg.getCount()).thenReturn(count); - return agg; - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java index cf54131af13..84194bd0bac 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.bucket.filter.Filter; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; @@ -18,9 +17,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter; import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class ConfusionMatrixTests extends AbstractSerializingTestCase { @@ -50,14 +48,14 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase { @@ -50,12 +48,12 @@ public class PrecisionTests extends AbstractSerializingTestCase { public void testEvaluate() { Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("precision_at_0.25_TP", 1L), - createFilterAgg("precision_at_0.25_FP", 4L), - createFilterAgg("precision_at_0.5_TP", 3L), - createFilterAgg("precision_at_0.5_FP", 1L), - createFilterAgg("precision_at_0.75_TP", 5L), - createFilterAgg("precision_at_0.75_FP", 0L) + mockFilter("precision_at_0.25_TP", 1L), + mockFilter("precision_at_0.25_FP", 4L), + mockFilter("precision_at_0.5_TP", 3L), + mockFilter("precision_at_0.5_FP", 1L), + mockFilter("precision_at_0.75_TP", 5L), + mockFilter("precision_at_0.75_FP", 0L) )); Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75)); @@ -67,8 +65,8 @@ public class PrecisionTests extends AbstractSerializingTestCase { public void testEvaluate_GivenZeroTpAndFp() { Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("precision_at_1.0_TP", 0L), - createFilterAgg("precision_at_1.0_FP", 0L) + mockFilter("precision_at_1.0_TP", 0L), + mockFilter("precision_at_1.0_FP", 0L) )); Precision precision = new Precision(Arrays.asList(1.0)); @@ -77,11 +75,4 @@ public class PrecisionTests extends AbstractSerializingTestCase { String expected = "{\"1.0\":0.0}"; assertThat(Strings.toString(result), equalTo(expected)); } - - private static Filter createFilterAgg(String name, long docCount) { - Filter agg = mock(Filter.class); - when(agg.getName()).thenReturn(name); - when(agg.getDocCount()).thenReturn(docCount); - return agg; - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java index 009805425cd..343d1905955 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.bucket.filter.Filter; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; @@ -18,9 +17,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter; import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class RecallTests extends AbstractSerializingTestCase { @@ -50,12 +48,12 @@ public class RecallTests extends AbstractSerializingTestCase { public void testEvaluate() { Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("recall_at_0.25_TP", 1L), - createFilterAgg("recall_at_0.25_FN", 4L), - createFilterAgg("recall_at_0.5_TP", 3L), - createFilterAgg("recall_at_0.5_FN", 1L), - createFilterAgg("recall_at_0.75_TP", 5L), - createFilterAgg("recall_at_0.75_FN", 0L) + mockFilter("recall_at_0.25_TP", 1L), + mockFilter("recall_at_0.25_FN", 4L), + mockFilter("recall_at_0.5_TP", 3L), + mockFilter("recall_at_0.5_FN", 1L), + mockFilter("recall_at_0.75_TP", 5L), + mockFilter("recall_at_0.75_FN", 0L) )); Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75)); @@ -67,8 +65,8 @@ public class RecallTests extends AbstractSerializingTestCase { public void testEvaluate_GivenZeroTpAndFp() { Aggregations aggs = new Aggregations(Arrays.asList( - createFilterAgg("recall_at_1.0_TP", 0L), - createFilterAgg("recall_at_1.0_FN", 0L) + mockFilter("recall_at_1.0_TP", 0L), + mockFilter("recall_at_1.0_FN", 0L) )); Recall recall = new Recall(Arrays.asList(1.0)); @@ -77,11 +75,4 @@ public class RecallTests extends AbstractSerializingTestCase { String expected = "{\"1.0\":0.0}"; assertThat(Strings.toString(result), equalTo(expected)); } - - private static Filter createFilterAgg(String name, long docCount) { - Filter agg = mock(Filter.class); - when(agg.getName()).thenReturn(name); - when(agg.getDocCount()).thenReturn(docCount); - return agg; - } }