parent
fbb92f527a
commit
a3f88595d7
|
@ -33,9 +33,6 @@ import org.elasticsearch.client.indices.CreateIndexRequest;
|
||||||
import org.elasticsearch.client.indices.GetIndexRequest;
|
import org.elasticsearch.client.indices.GetIndexRequest;
|
||||||
import org.elasticsearch.client.ml.CloseJobRequest;
|
import org.elasticsearch.client.ml.CloseJobRequest;
|
||||||
import org.elasticsearch.client.ml.CloseJobResponse;
|
import org.elasticsearch.client.ml.CloseJobResponse;
|
||||||
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
|
|
||||||
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
|
|
||||||
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
|
|
||||||
import org.elasticsearch.client.ml.DeleteCalendarEventRequest;
|
import org.elasticsearch.client.ml.DeleteCalendarEventRequest;
|
||||||
import org.elasticsearch.client.ml.DeleteCalendarJobRequest;
|
import org.elasticsearch.client.ml.DeleteCalendarJobRequest;
|
||||||
import org.elasticsearch.client.ml.DeleteCalendarRequest;
|
import org.elasticsearch.client.ml.DeleteCalendarRequest;
|
||||||
|
@ -48,8 +45,11 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
|
||||||
import org.elasticsearch.client.ml.DeleteJobRequest;
|
import org.elasticsearch.client.ml.DeleteJobRequest;
|
||||||
import org.elasticsearch.client.ml.DeleteJobResponse;
|
import org.elasticsearch.client.ml.DeleteJobResponse;
|
||||||
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
|
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
|
||||||
|
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
|
||||||
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
|
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
|
||||||
import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
|
import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
|
||||||
|
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
|
||||||
|
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
|
||||||
import org.elasticsearch.client.ml.FindFileStructureRequest;
|
import org.elasticsearch.client.ml.FindFileStructureRequest;
|
||||||
import org.elasticsearch.client.ml.FindFileStructureResponse;
|
import org.elasticsearch.client.ml.FindFileStructureResponse;
|
||||||
import org.elasticsearch.client.ml.FlushJobRequest;
|
import org.elasticsearch.client.ml.FlushJobRequest;
|
||||||
|
@ -135,8 +135,6 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||||
|
@ -1852,9 +1850,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
accuracyResult.getActualClasses(),
|
accuracyResult.getActualClasses(),
|
||||||
equalTo(
|
equalTo(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly
|
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||||
new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly
|
new AccuracyMetric.ActualClass("cat", 5, 0.6),
|
||||||
new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly
|
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||||
|
new AccuracyMetric.ActualClass("dog", 4, 0.75),
|
||||||
|
// no examples labeled as "ant" were classified correctly
|
||||||
|
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
|
||||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
|
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
|
||||||
}
|
}
|
||||||
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
||||||
|
@ -1876,20 +1877,29 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
mcmResult.getConfusionMatrix(),
|
mcmResult.getConfusionMatrix(),
|
||||||
equalTo(
|
equalTo(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
new ActualClass(
|
new MulticlassConfusionMatrixMetric.ActualClass(
|
||||||
"ant",
|
"ant",
|
||||||
1L,
|
1L,
|
||||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
|
Arrays.asList(
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 0L)),
|
||||||
0L),
|
0L),
|
||||||
new ActualClass(
|
new MulticlassConfusionMatrixMetric.ActualClass(
|
||||||
"cat",
|
"cat",
|
||||||
5L,
|
5L,
|
||||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
|
Arrays.asList(
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)),
|
||||||
1L),
|
1L),
|
||||||
new ActualClass(
|
new MulticlassConfusionMatrixMetric.ActualClass(
|
||||||
"dog",
|
"dog",
|
||||||
4L,
|
4L,
|
||||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
|
Arrays.asList(
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("ant", 0L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)),
|
||||||
0L))));
|
0L))));
|
||||||
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
|
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
|
||||||
}
|
}
|
||||||
|
@ -1912,8 +1922,20 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
mcmResult.getConfusionMatrix(),
|
mcmResult.getConfusionMatrix(),
|
||||||
equalTo(
|
equalTo(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
|
new MulticlassConfusionMatrixMetric.ActualClass(
|
||||||
new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
|
"cat",
|
||||||
|
5L,
|
||||||
|
Arrays.asList(
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 3L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 1L)),
|
||||||
|
1L),
|
||||||
|
new MulticlassConfusionMatrixMetric.ActualClass(
|
||||||
|
"dog",
|
||||||
|
4L,
|
||||||
|
Arrays.asList(
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("cat", 1L),
|
||||||
|
new MulticlassConfusionMatrixMetric.PredictedClass("dog", 3L)),
|
||||||
|
0L)
|
||||||
)));
|
)));
|
||||||
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
|
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,36 +20,56 @@ package org.elasticsearch.client.ml;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetricResultTests;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetricResultTests;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricResultTests;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricResultTests;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetricResultTests;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetricResultTests;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<EvaluateDataFrameResponse> {
|
public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<EvaluateDataFrameResponse> {
|
||||||
|
|
||||||
public static EvaluateDataFrameResponse randomResponse() {
|
public static EvaluateDataFrameResponse randomResponse() {
|
||||||
List<EvaluationMetric.Result> metrics = new ArrayList<>();
|
String evaluationName = randomFrom(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME);
|
||||||
if (randomBoolean()) {
|
List<EvaluationMetric.Result> metrics;
|
||||||
metrics.add(AucRocMetricResultTests.randomResult());
|
switch (evaluationName) {
|
||||||
|
case BinarySoftClassification.NAME:
|
||||||
|
metrics = randomSubsetOf(
|
||||||
|
Arrays.asList(
|
||||||
|
AucRocMetricResultTests.randomResult(),
|
||||||
|
PrecisionMetricResultTests.randomResult(),
|
||||||
|
RecallMetricResultTests.randomResult(),
|
||||||
|
ConfusionMatrixMetricResultTests.randomResult()));
|
||||||
|
break;
|
||||||
|
case Regression.NAME:
|
||||||
|
metrics = randomSubsetOf(
|
||||||
|
Arrays.asList(
|
||||||
|
MeanSquaredErrorMetricResultTests.randomResult(),
|
||||||
|
RSquaredMetricResultTests.randomResult()));
|
||||||
|
break;
|
||||||
|
case Classification.NAME:
|
||||||
|
metrics = randomSubsetOf(
|
||||||
|
Arrays.asList(
|
||||||
|
AccuracyMetricResultTests.randomResult(),
|
||||||
|
MulticlassConfusionMatrixMetricResultTests.randomResult()));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new AssertionError("Please add missing \"case\" variant to the \"switch\" statement");
|
||||||
}
|
}
|
||||||
if (randomBoolean()) {
|
return new EvaluateDataFrameResponse(evaluationName, metrics);
|
||||||
metrics.add(PrecisionMetricResultTests.randomResult());
|
|
||||||
}
|
|
||||||
if (randomBoolean()) {
|
|
||||||
metrics.add(RecallMetricResultTests.randomResult());
|
|
||||||
}
|
|
||||||
if (randomBoolean()) {
|
|
||||||
metrics.add(ConfusionMatrixMetricResultTests.randomResult());
|
|
||||||
}
|
|
||||||
if (randomBoolean()) {
|
|
||||||
metrics.add(MeanSquaredErrorMetricResultTests.randomResult());
|
|
||||||
}
|
|
||||||
return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -31,15 +31,14 @@ import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
public class AccuracyMetricResultTests extends AbstractXContentTestCase<AccuracyMetric.Result> {
|
public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected NamedXContentRegistry xContentRegistry() {
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public static Result randomResult() {
|
||||||
protected AccuracyMetric.Result createTestInstance() {
|
|
||||||
int numClasses = randomIntBetween(2, 100);
|
int numClasses = randomIntBetween(2, 100);
|
||||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||||
|
@ -52,8 +51,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Accuracy
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected AccuracyMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
protected Result createTestInstance() {
|
||||||
return AccuracyMetric.Result.fromXContent(parser);
|
return randomResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Result doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return Result.fromXContent(parser);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -38,7 +38,10 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
||||||
|
|
||||||
static Classification createRandom() {
|
static Classification createRandom() {
|
||||||
List<EvaluationMetric> metrics =
|
List<EvaluationMetric> metrics =
|
||||||
randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom()));
|
randomSubsetOf(
|
||||||
|
Arrays.asList(
|
||||||
|
AccuracyMetricTests.createRandom(),
|
||||||
|
MulticlassConfusionMatrixMetricTests.createRandom()));
|
||||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,8 +40,7 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
|
||||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public static Result randomResult() {
|
||||||
protected Result createTestInstance() {
|
|
||||||
int numClasses = randomIntBetween(2, 100);
|
int numClasses = randomIntBetween(2, 100);
|
||||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||||
|
@ -60,6 +59,11 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
|
||||||
return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
|
return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Result createTestInstance() {
|
||||||
|
return randomResult();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Result doParseInstance(XContentParser parser) throws IOException {
|
protected Result doParseInstance(XContentParser parser) throws IOException {
|
||||||
return Result.fromXContent(parser);
|
return Result.fromXContent(parser);
|
||||||
|
|
|
@ -16,9 +16,8 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml;
|
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
|
@ -16,9 +16,8 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml;
|
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
@ -27,11 +26,11 @@ import java.util.function.Predicate;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static org.elasticsearch.client.ml.AucRocMetricAucRocPointTests.randomPoint;
|
import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetricAucRocPointTests.randomPoint;
|
||||||
|
|
||||||
public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetric.Result> {
|
public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetric.Result> {
|
||||||
|
|
||||||
static AucRocMetric.Result randomResult() {
|
public static AucRocMetric.Result randomResult() {
|
||||||
return new AucRocMetric.Result(
|
return new AucRocMetric.Result(
|
||||||
randomDouble(),
|
randomDouble(),
|
||||||
Stream
|
Stream
|
|
@ -16,9 +16,8 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml;
|
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
|
@ -16,9 +16,8 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml;
|
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
@ -27,11 +26,11 @@ import java.util.function.Predicate;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static org.elasticsearch.client.ml.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix;
|
import static org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix;
|
||||||
|
|
||||||
public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase<ConfusionMatrixMetric.Result> {
|
public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase<ConfusionMatrixMetric.Result> {
|
||||||
|
|
||||||
static ConfusionMatrixMetric.Result randomResult() {
|
public static ConfusionMatrixMetric.Result randomResult() {
|
||||||
return new ConfusionMatrixMetric.Result(
|
return new ConfusionMatrixMetric.Result(
|
||||||
Stream
|
Stream
|
||||||
.generate(() -> randomConfusionMatrix())
|
.generate(() -> randomConfusionMatrix())
|
|
@ -16,9 +16,8 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml;
|
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
@ -29,7 +28,7 @@ import java.util.stream.Stream;
|
||||||
|
|
||||||
public class PrecisionMetricResultTests extends AbstractXContentTestCase<PrecisionMetric.Result> {
|
public class PrecisionMetricResultTests extends AbstractXContentTestCase<PrecisionMetric.Result> {
|
||||||
|
|
||||||
static PrecisionMetric.Result randomResult() {
|
public static PrecisionMetric.Result randomResult() {
|
||||||
return new PrecisionMetric.Result(
|
return new PrecisionMetric.Result(
|
||||||
Stream
|
Stream
|
||||||
.generate(() -> randomDouble())
|
.generate(() -> randomDouble())
|
|
@ -16,9 +16,8 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml;
|
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
@ -29,7 +28,7 @@ import java.util.stream.Stream;
|
||||||
|
|
||||||
public class RecallMetricResultTests extends AbstractXContentTestCase<RecallMetric.Result> {
|
public class RecallMetricResultTests extends AbstractXContentTestCase<RecallMetric.Result> {
|
||||||
|
|
||||||
static RecallMetric.Result randomResult() {
|
public static RecallMetric.Result randomResult() {
|
||||||
return new RecallMetric.Result(
|
return new RecallMetric.Result(
|
||||||
Stream
|
Stream
|
||||||
.generate(() -> randomDouble())
|
.generate(() -> randomDouble())
|
|
@ -0,0 +1,96 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||||
|
|
||||||
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
|
||||||
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.mockito.Mockito.doReturn;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public final class MockAggregations {
|
||||||
|
|
||||||
|
public static Terms mockTerms(String name) {
|
||||||
|
return mockTerms(name, Collections.emptyList(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
|
||||||
|
Terms agg = mock(Terms.class);
|
||||||
|
when(agg.getName()).thenReturn(name);
|
||||||
|
doReturn(buckets).when(agg).getBuckets();
|
||||||
|
when(agg.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
|
||||||
|
return agg;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
|
||||||
|
Terms.Bucket bucket = mock(Terms.Bucket.class);
|
||||||
|
when(bucket.getKeyAsString()).thenReturn(key);
|
||||||
|
when(bucket.getAggregations()).thenReturn(subAggs);
|
||||||
|
return bucket;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Filters mockFilters(String name) {
|
||||||
|
return mockFilters(name, Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
|
||||||
|
Filters agg = mock(Filters.class);
|
||||||
|
when(agg.getName()).thenReturn(name);
|
||||||
|
doReturn(buckets).when(agg).getBuckets();
|
||||||
|
return agg;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
|
||||||
|
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
|
||||||
|
when(bucket.getAggregations()).thenReturn(subAggs);
|
||||||
|
return bucket;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Filters.Bucket mockFiltersBucket(String key, long docCount) {
|
||||||
|
Filters.Bucket bucket = mock(Filters.Bucket.class);
|
||||||
|
when(bucket.getKeyAsString()).thenReturn(key);
|
||||||
|
when(bucket.getDocCount()).thenReturn(docCount);
|
||||||
|
return bucket;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Filter mockFilter(String name, long docCount) {
|
||||||
|
Filter agg = mock(Filter.class);
|
||||||
|
when(agg.getName()).thenReturn(name);
|
||||||
|
when(agg.getDocCount()).thenReturn(docCount);
|
||||||
|
return agg;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static NumericMetricsAggregation.SingleValue mockSingleValue(String name, double value) {
|
||||||
|
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
|
||||||
|
when(agg.getName()).thenReturn(name);
|
||||||
|
when(agg.value()).thenReturn(value);
|
||||||
|
return agg;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Cardinality mockCardinality(String name, long value) {
|
||||||
|
Cardinality agg = mock(Cardinality.class);
|
||||||
|
when(agg.getName()).thenReturn(name);
|
||||||
|
when(agg.getValue()).thenReturn(value);
|
||||||
|
return agg;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ExtendedStats mockExtendedStats(String name, double variance, long count) {
|
||||||
|
ExtendedStats agg = mock(ExtendedStats.class);
|
||||||
|
when(agg.getName()).thenReturn(name);
|
||||||
|
when(agg.getVariance()).thenReturn(variance);
|
||||||
|
when(agg.getCount()).thenReturn(count);
|
||||||
|
return agg;
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,17 +8,15 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
|
|
||||||
|
@ -48,9 +46,9 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
|
|
||||||
public void testProcess() {
|
public void testProcess() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createTermsAgg("classification_classes"),
|
mockTerms("classification_classes"),
|
||||||
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
|
mockSingleValue("classification_overall_accuracy", 0.8123),
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
|
@ -62,16 +60,16 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
public void testProcess_GivenMissingAgg() {
|
public void testProcess_GivenMissingAgg() {
|
||||||
{
|
{
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createTermsAgg("classification_classes"),
|
mockTerms("classification_classes"),
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
|
mockSingleValue("classification_overall_accuracy", 0.8123),
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
||||||
|
@ -81,32 +79,19 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
public void testProcess_GivenAggOfWrongType() {
|
public void testProcess_GivenAggOfWrongType() {
|
||||||
{
|
{
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createTermsAgg("classification_classes"),
|
mockTerms("classification_classes"),
|
||||||
createTermsAgg("classification_overall_accuracy")
|
mockTerms("classification_overall_accuracy")
|
||||||
));
|
));
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("classification_classes", 1.0),
|
mockSingleValue("classification_classes", 1.0),
|
||||||
createSingleMetricAgg("classification_overall_accuracy", 0.8123)
|
mockSingleValue("classification_overall_accuracy", 0.8123)
|
||||||
));
|
));
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
|
||||||
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
when(agg.value()).thenReturn(value);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Terms createTermsAgg(String name) {
|
|
||||||
Terms agg = mock(Terms.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,7 +52,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||||
|
|
||||||
public static Classification createRandom() {
|
public static Classification createRandom() {
|
||||||
List<ClassificationMetric> metrics =
|
List<ClassificationMetric> metrics =
|
||||||
randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom()));
|
randomSubsetOf(
|
||||||
|
Arrays.asList(
|
||||||
|
AccuracyTests.createRandom(),
|
||||||
|
MulticlassConfusionMatrixTests.createRandom()));
|
||||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,6 @@ import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
|
||||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
|
||||||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||||
|
@ -21,15 +18,17 @@ import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
|
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.is;
|
import static org.hamcrest.Matchers.is;
|
||||||
import static org.hamcrest.Matchers.not;
|
import static org.hamcrest.Matchers.not;
|
||||||
import static org.mockito.Mockito.doReturn;
|
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
|
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
|
||||||
|
|
||||||
|
@ -77,7 +76,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
|
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
|
||||||
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
|
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
|
||||||
assertThat(aggs, is(not(empty())));
|
assertThat(aggs, is(not(empty())));
|
||||||
assertThat(confusionMatrix.getResult(), equalTo(Optional.empty()));
|
assertThat(confusionMatrix.getResult(), isEmpty());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
|
@ -163,46 +162,4 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
|
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
|
||||||
assertThat(result.getOtherActualClassCount(), equalTo(3L));
|
assertThat(result.getOtherActualClassCount(), equalTo(3L));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
|
|
||||||
Terms aggregation = mock(Terms.class);
|
|
||||||
when(aggregation.getName()).thenReturn(name);
|
|
||||||
doReturn(buckets).when(aggregation).getBuckets();
|
|
||||||
when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
|
|
||||||
return aggregation;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
|
|
||||||
Terms.Bucket bucket = mock(Terms.Bucket.class);
|
|
||||||
when(bucket.getKeyAsString()).thenReturn(key);
|
|
||||||
when(bucket.getAggregations()).thenReturn(subAggs);
|
|
||||||
return bucket;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
|
|
||||||
Filters aggregation = mock(Filters.class);
|
|
||||||
when(aggregation.getName()).thenReturn(name);
|
|
||||||
doReturn(buckets).when(aggregation).getBuckets();
|
|
||||||
return aggregation;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
|
|
||||||
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
|
|
||||||
when(bucket.getAggregations()).thenReturn(subAggs);
|
|
||||||
return bucket;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Filters.Bucket mockFiltersBucket(String key, long docCount) {
|
|
||||||
Filters.Bucket bucket = mock(Filters.Bucket.class);
|
|
||||||
when(bucket.getKeyAsString()).thenReturn(key);
|
|
||||||
when(bucket.getDocCount()).thenReturn(docCount);
|
|
||||||
return bucket;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Cardinality mockCardinality(String name, long value) {
|
|
||||||
Cardinality aggregation = mock(Cardinality.class);
|
|
||||||
when(aggregation.getName()).thenReturn(name);
|
|
||||||
when(aggregation.getValue()).thenReturn(value);
|
|
||||||
return aggregation;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
|
||||||
|
@ -17,9 +16,8 @@ import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
|
public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
|
||||||
|
|
||||||
|
@ -44,8 +42,8 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("regression_mean_squared_error", 0.8123),
|
mockSingleValue("regression_mean_squared_error", 0.8123),
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
MeanSquaredError mse = new MeanSquaredError();
|
MeanSquaredError mse = new MeanSquaredError();
|
||||||
|
@ -58,7 +56,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
||||||
|
|
||||||
public void testEvaluate_GivenMissingAggs() {
|
public void testEvaluate_GivenMissingAggs() {
|
||||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
MeanSquaredError mse = new MeanSquaredError();
|
MeanSquaredError mse = new MeanSquaredError();
|
||||||
|
@ -67,11 +65,4 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
||||||
EvaluationMetricResult result = mse.getResult().get();
|
EvaluationMetricResult result = mse.getResult().get();
|
||||||
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
|
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
|
||||||
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
when(agg.value()).thenReturn(value);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,8 +9,6 @@ import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
|
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
|
||||||
|
@ -18,9 +16,9 @@ import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockExtendedStats;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
|
|
||||||
|
@ -45,10 +43,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("residual_sum_of_squares", 10_111),
|
mockSingleValue("residual_sum_of_squares", 10_111),
|
||||||
createExtendedStatsAgg("extended_stats_actual", 155.23, 1000),
|
mockExtendedStats("extended_stats_actual", 155.23, 1000),
|
||||||
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
|
mockExtendedStats("some_other_extended_stats",99.1, 10_000),
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
RSquared rSquared = new RSquared();
|
RSquared rSquared = new RSquared();
|
||||||
|
@ -61,10 +59,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
|
|
||||||
public void testEvaluateWithZeroCount() {
|
public void testEvaluateWithZeroCount() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("residual_sum_of_squares", 0),
|
mockSingleValue("residual_sum_of_squares", 0),
|
||||||
createExtendedStatsAgg("extended_stats_actual", 0.0, 0),
|
mockExtendedStats("extended_stats_actual", 0.0, 0),
|
||||||
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
|
mockExtendedStats("some_other_extended_stats",99.1, 10_000),
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
RSquared rSquared = new RSquared();
|
RSquared rSquared = new RSquared();
|
||||||
|
@ -76,10 +74,10 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
|
|
||||||
public void testEvaluateWithSingleCountZeroVariance() {
|
public void testEvaluateWithSingleCountZeroVariance() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("residual_sum_of_squares", 1),
|
mockSingleValue("residual_sum_of_squares", 1),
|
||||||
createExtendedStatsAgg("extended_stats_actual", 0.0, 1),
|
mockExtendedStats("extended_stats_actual", 0.0, 1),
|
||||||
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
|
mockExtendedStats("some_other_extended_stats",99.1, 10_000),
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
RSquared rSquared = new RSquared();
|
RSquared rSquared = new RSquared();
|
||||||
|
@ -91,7 +89,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
|
|
||||||
public void testEvaluate_GivenMissingAggs() {
|
public void testEvaluate_GivenMissingAggs() {
|
||||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
RSquared rSquared = new RSquared();
|
RSquared rSquared = new RSquared();
|
||||||
|
@ -103,8 +101,8 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
|
|
||||||
public void testEvaluate_GivenMissingExtendedStatsAgg() {
|
public void testEvaluate_GivenMissingExtendedStatsAgg() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
mockSingleValue("some_other_single_metric_agg", 0.2377),
|
||||||
createSingleMetricAgg("residual_sum_of_squares", 0.2377)
|
mockSingleValue("residual_sum_of_squares", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
RSquared rSquared = new RSquared();
|
RSquared rSquared = new RSquared();
|
||||||
|
@ -116,8 +114,8 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
|
|
||||||
public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() {
|
public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
mockSingleValue("some_other_single_metric_agg", 0.2377),
|
||||||
createExtendedStatsAgg("extended_stats_actual",100, 50)
|
mockExtendedStats("extended_stats_actual",100, 50)
|
||||||
));
|
));
|
||||||
|
|
||||||
RSquared rSquared = new RSquared();
|
RSquared rSquared = new RSquared();
|
||||||
|
@ -126,19 +124,4 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
EvaluationMetricResult result = rSquared.getResult().get();
|
EvaluationMetricResult result = rSquared.getResult().get();
|
||||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
|
||||||
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
when(agg.value()).thenReturn(value);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static ExtendedStats createExtendedStatsAgg(String name, double variance, long count) {
|
|
||||||
ExtendedStats agg = mock(ExtendedStats.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
when(agg.getVariance()).thenReturn(variance);
|
|
||||||
when(agg.getCount()).thenReturn(count);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
|
||||||
|
@ -18,9 +17,8 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionMatrix> {
|
public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionMatrix> {
|
||||||
|
|
||||||
|
@ -50,14 +48,14 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createFilterAgg("confusion_matrix_at_0.25_TP", 1L),
|
mockFilter("confusion_matrix_at_0.25_TP", 1L),
|
||||||
createFilterAgg("confusion_matrix_at_0.25_FP", 2L),
|
mockFilter("confusion_matrix_at_0.25_FP", 2L),
|
||||||
createFilterAgg("confusion_matrix_at_0.25_TN", 3L),
|
mockFilter("confusion_matrix_at_0.25_TN", 3L),
|
||||||
createFilterAgg("confusion_matrix_at_0.25_FN", 4L),
|
mockFilter("confusion_matrix_at_0.25_FN", 4L),
|
||||||
createFilterAgg("confusion_matrix_at_0.5_TP", 5L),
|
mockFilter("confusion_matrix_at_0.5_TP", 5L),
|
||||||
createFilterAgg("confusion_matrix_at_0.5_FP", 6L),
|
mockFilter("confusion_matrix_at_0.5_FP", 6L),
|
||||||
createFilterAgg("confusion_matrix_at_0.5_TN", 7L),
|
mockFilter("confusion_matrix_at_0.5_TN", 7L),
|
||||||
createFilterAgg("confusion_matrix_at_0.5_FN", 8L)
|
mockFilter("confusion_matrix_at_0.5_FN", 8L)
|
||||||
));
|
));
|
||||||
|
|
||||||
ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5));
|
ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5));
|
||||||
|
@ -66,11 +64,4 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
|
||||||
String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}";
|
String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}";
|
||||||
assertThat(Strings.toString(result), equalTo(expected));
|
assertThat(Strings.toString(result), equalTo(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Filter createFilterAgg(String name, long docCount) {
|
|
||||||
Filter agg = mock(Filter.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
when(agg.getDocCount()).thenReturn(docCount);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
|
||||||
|
@ -18,9 +17,8 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||||
|
|
||||||
|
@ -50,12 +48,12 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createFilterAgg("precision_at_0.25_TP", 1L),
|
mockFilter("precision_at_0.25_TP", 1L),
|
||||||
createFilterAgg("precision_at_0.25_FP", 4L),
|
mockFilter("precision_at_0.25_FP", 4L),
|
||||||
createFilterAgg("precision_at_0.5_TP", 3L),
|
mockFilter("precision_at_0.5_TP", 3L),
|
||||||
createFilterAgg("precision_at_0.5_FP", 1L),
|
mockFilter("precision_at_0.5_FP", 1L),
|
||||||
createFilterAgg("precision_at_0.75_TP", 5L),
|
mockFilter("precision_at_0.75_TP", 5L),
|
||||||
createFilterAgg("precision_at_0.75_FP", 0L)
|
mockFilter("precision_at_0.75_FP", 0L)
|
||||||
));
|
));
|
||||||
|
|
||||||
Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75));
|
Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75));
|
||||||
|
@ -67,8 +65,8 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||||
|
|
||||||
public void testEvaluate_GivenZeroTpAndFp() {
|
public void testEvaluate_GivenZeroTpAndFp() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createFilterAgg("precision_at_1.0_TP", 0L),
|
mockFilter("precision_at_1.0_TP", 0L),
|
||||||
createFilterAgg("precision_at_1.0_FP", 0L)
|
mockFilter("precision_at_1.0_FP", 0L)
|
||||||
));
|
));
|
||||||
|
|
||||||
Precision precision = new Precision(Arrays.asList(1.0));
|
Precision precision = new Precision(Arrays.asList(1.0));
|
||||||
|
@ -77,11 +75,4 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||||
String expected = "{\"1.0\":0.0}";
|
String expected = "{\"1.0\":0.0}";
|
||||||
assertThat(Strings.toString(result), equalTo(expected));
|
assertThat(Strings.toString(result), equalTo(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Filter createFilterAgg(String name, long docCount) {
|
|
||||||
Filter agg = mock(Filter.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
when(agg.getDocCount()).thenReturn(docCount);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
|
||||||
|
@ -18,9 +17,8 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilter;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||||
|
|
||||||
|
@ -50,12 +48,12 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createFilterAgg("recall_at_0.25_TP", 1L),
|
mockFilter("recall_at_0.25_TP", 1L),
|
||||||
createFilterAgg("recall_at_0.25_FN", 4L),
|
mockFilter("recall_at_0.25_FN", 4L),
|
||||||
createFilterAgg("recall_at_0.5_TP", 3L),
|
mockFilter("recall_at_0.5_TP", 3L),
|
||||||
createFilterAgg("recall_at_0.5_FN", 1L),
|
mockFilter("recall_at_0.5_FN", 1L),
|
||||||
createFilterAgg("recall_at_0.75_TP", 5L),
|
mockFilter("recall_at_0.75_TP", 5L),
|
||||||
createFilterAgg("recall_at_0.75_FN", 0L)
|
mockFilter("recall_at_0.75_FN", 0L)
|
||||||
));
|
));
|
||||||
|
|
||||||
Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75));
|
Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75));
|
||||||
|
@ -67,8 +65,8 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||||
|
|
||||||
public void testEvaluate_GivenZeroTpAndFp() {
|
public void testEvaluate_GivenZeroTpAndFp() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
createFilterAgg("recall_at_1.0_TP", 0L),
|
mockFilter("recall_at_1.0_TP", 0L),
|
||||||
createFilterAgg("recall_at_1.0_FN", 0L)
|
mockFilter("recall_at_1.0_FN", 0L)
|
||||||
));
|
));
|
||||||
|
|
||||||
Recall recall = new Recall(Arrays.asList(1.0));
|
Recall recall = new Recall(Arrays.asList(1.0));
|
||||||
|
@ -77,11 +75,4 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||||
String expected = "{\"1.0\":0.0}";
|
String expected = "{\"1.0\":0.0}";
|
||||||
assertThat(Strings.toString(result), equalTo(expected));
|
assertThat(Strings.toString(result), equalTo(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Filter createFilterAgg(String name, long docCount) {
|
|
||||||
Filter agg = mock(Filter.class);
|
|
||||||
when(agg.getName()).thenReturn(name);
|
|
||||||
when(agg.getDocCount()).thenReturn(docCount);
|
|
||||||
return agg;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue