parent
562a9eff33
commit
8c4c19d310
|
@ -66,7 +66,7 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
* Builds the search required to collect data to compute the evaluation result
|
||||
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
|
||||
*/
|
||||
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
|
||||
Objects.requireNonNull(userProvidedQueryBuilder);
|
||||
BoolQueryBuilder boolQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
|
@ -78,7 +78,8 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
for (EvaluationMetric metric : getMetrics()) {
|
||||
// Fetch aggregations requested by individual metrics
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField());
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
|
||||
metric.aggs(parameters, getActualField(), getPredictedField());
|
||||
aggs.v1().forEach(searchSourceBuilder::aggregation);
|
||||
aggs.v2().forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
|
|
|
@ -28,11 +28,14 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
|||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param parameters settings that may be needed by aggregations
|
||||
* @param actualField the field that stores the actual value
|
||||
* @param predictedField the field that stores the predicted value (class name or probability)
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField);
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
/**
|
||||
* Encapsulates parameters needed by evaluation.
|
||||
*/
|
||||
public class EvaluationParameters {
|
||||
|
||||
/**
|
||||
* Maximum number of buckets allowed in any single search request.
|
||||
*/
|
||||
private final int maxBuckets;
|
||||
|
||||
public EvaluationParameters(int maxBuckets) {
|
||||
this.maxBuckets = maxBuckets;
|
||||
}
|
||||
|
||||
public int getMaxBuckets() {
|
||||
return maxBuckets;
|
||||
}
|
||||
}
|
|
@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -103,7 +104,9 @@ public class Accuracy implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField.trySet(actualField);
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
|
@ -112,7 +115,8 @@ public class Accuracy implements EvaluationMetric {
|
|||
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
|
||||
}
|
||||
if (result.get() == null) {
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
|
||||
matrix.aggs(parameters, actualField, predictedField);
|
||||
aggs.addAll(matrixAggs.v1());
|
||||
pipelineAggs.addAll(matrixAggs.v2());
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
|||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -73,9 +74,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
}
|
||||
|
||||
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
||||
static final String STEP_1_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_cardinality_of_actual_class";
|
||||
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
||||
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
||||
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
|
||||
private static final String OTHER_BUCKET_KEY = "_other_";
|
||||
private static final String DEFAULT_AGG_NAME_PREFIX = "";
|
||||
private static final int DEFAULT_SIZE = 10;
|
||||
|
@ -84,6 +85,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
private final int size;
|
||||
private final String aggNamePrefix;
|
||||
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
|
||||
private final SetOnce<Long> actualClassesCardinality = new SetOnce<>();
|
||||
/** Accumulates actual classes processed so far. It may take more than 1 call to #process method to fill this field completely. */
|
||||
private final List<ActualClass> actualClasses = new ArrayList<>();
|
||||
private final SetOnce<Result> result = new SetOnce<>();
|
||||
|
||||
public MulticlassConfusionMatrix() {
|
||||
|
@ -122,35 +126,46 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
if (topActualClassNames.get() == null) { // This is step 1
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(size)),
|
||||
.size(size),
|
||||
AggregationBuilders.cardinality(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS))
|
||||
.field(actualField)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
if (result.get() == null) { // This is step 2
|
||||
KeyedFilter[] keyedFiltersActual =
|
||||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
if (result.get() == null) { // These are steps 2, 3, 4 etc.
|
||||
KeyedFilter[] keyedFiltersPredicted =
|
||||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
// Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that
|
||||
// too_many_buckets_exception exception is not thrown.
|
||||
// The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch.
|
||||
// In such case, the exception will be thrown telling the user they should increase the value of "search.max_buckets".
|
||||
int actualClassesPerBatch = Math.max(parameters.getMaxBuckets() / (topActualClassNames.get().size() + 2), 1);
|
||||
KeyedFilter[] keyedFiltersActual =
|
||||
topActualClassNames.get().stream()
|
||||
.skip(actualClasses.size())
|
||||
.limit(actualClassesPerBatch)
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
if (keyedFiltersActual.length > 0) {
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
|
||||
.field(actualField),
|
||||
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
|
||||
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
|
||||
.otherBucket(true)
|
||||
.otherBucketKey(OTHER_BUCKET_KEY))),
|
||||
Collections.emptyList());
|
||||
}
|
||||
}
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
|
@ -160,10 +175,12 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
|
||||
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
|
||||
}
|
||||
if (actualClassesCardinality.get() == null && aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)) != null) {
|
||||
Cardinality cardinalityAgg = aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS));
|
||||
actualClassesCardinality.set(cardinalityAgg.getValue());
|
||||
}
|
||||
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
|
||||
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
|
||||
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
|
||||
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
|
||||
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
|
||||
String actualClass = bucket.getKeyAsString();
|
||||
long actualClassDocCount = bucket.getDocCount();
|
||||
|
@ -182,7 +199,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
|
||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
|
||||
}
|
||||
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
|
||||
if (actualClasses.size() == topActualClassNames.get().size()) {
|
||||
result.set(new Result(actualClasses, Math.max(actualClassesCardinality.get() - size, 0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
|||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -97,7 +98,9 @@ public class Precision implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField.trySet(actualField);
|
||||
if (topActualClassNames.get() == null) { // This is step 1
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
|||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -90,7 +91,9 @@ public class Recall implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField.trySet(actualField);
|
||||
if (result.get() != null) {
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
|
@ -67,7 +68,9 @@ public class MeanSquaredError implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBui
|
|||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
|
@ -72,7 +73,9 @@ public class RSquared implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
|
|||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -66,7 +67,9 @@ abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedProbabilityField) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
|||
import org.elasticsearch.search.aggregations.metrics.Percentiles;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -127,7 +128,9 @@ public class AucRoc implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedProbabilityField) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class EvaluationParametersTests extends ESTestCase {
|
||||
|
||||
public void testConstructorAndGetters() {
|
||||
EvaluationParameters params = new EvaluationParameters(17);
|
||||
assertThat(params.getMaxBuckets(), equalTo(17));
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||
|
||||
|
@ -17,19 +18,21 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||
|
||||
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||
|
||||
@Override
|
||||
protected Accuracy doParseInstance(XContentParser parser) throws IOException {
|
||||
return Accuracy.fromXContent(parser);
|
||||
|
@ -62,6 +65,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
|||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
|
||||
mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
|
@ -79,13 +83,12 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
|||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
|
||||
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||
|
||||
Accuracy accuracy = new Accuracy();
|
||||
accuracy.process(aggs);
|
||||
|
||||
assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(accuracy.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||
|
||||
Result result = accuracy.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
|
@ -106,6 +109,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
|||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
|
||||
mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
|
@ -123,11 +127,10 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
|||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
|
||||
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||
|
||||
Accuracy accuracy = new Accuracy();
|
||||
accuracy.aggs("foo", "bar");
|
||||
accuracy.aggs(EVALUATION_PARAMETERS, "foo", "bar");
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
|
|||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -43,6 +44,8 @@ import static org.mockito.Mockito.when;
|
|||
|
||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||
|
||||
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
|
@ -100,7 +103,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
|
||||
Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
|
||||
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
|
||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||
}
|
||||
|
@ -196,7 +199,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
|
||||
|
@ -23,18 +24,20 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
|
||||
|
||||
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrix.fromXContent(parser);
|
||||
|
@ -80,12 +83,12 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
|
||||
public void testAggs() {
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs("act", "pred");
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred");
|
||||
assertThat(aggs, isTuple(not(empty()), empty()));
|
||||
assertThat(confusionMatrix.getResult(), isEmpty());
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
public void testProcess() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
|
@ -93,6 +96,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
0L),
|
||||
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L),
|
||||
mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
|
@ -109,13 +113,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
new Aggregations(Arrays.asList(mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))))
|
||||
));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||
Result result = confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
|
@ -127,7 +131,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
assertThat(result.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_OtherClassesCountGreaterThanZero() {
|
||||
public void testProcess_OtherClassesCountGreaterThanZero() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
|
@ -135,6 +139,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 5L),
|
||||
mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
|
@ -151,13 +156,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
new Aggregations(Arrays.asList(mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
|
||||
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))))
|
||||
));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||
Result result = confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
|
@ -168,4 +173,106 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
|
||||
assertThat(result.getOtherActualClassCount(), equalTo(3L));
|
||||
}
|
||||
|
||||
public void testProcess_MoreThanTwoStepsNeeded() {
|
||||
Aggregations aggsStep1 = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockTermsBucket("ant", new Aggregations(Arrays.asList())),
|
||||
mockTermsBucket("cat", new Aggregations(Arrays.asList())),
|
||||
mockTermsBucket("dog", new Aggregations(Arrays.asList())),
|
||||
mockTermsBucket("fox", new Aggregations(Arrays.asList()))),
|
||||
0L),
|
||||
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L)
|
||||
));
|
||||
Aggregations aggsStep2 = new Aggregations(Arrays.asList(
|
||||
mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket(
|
||||
"ant",
|
||||
46,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("ant", 10L),
|
||||
mockFiltersBucket("cat", 11L),
|
||||
mockFiltersBucket("dog", 12L),
|
||||
mockFiltersBucket("fox", 13L),
|
||||
mockFiltersBucket("_other_", 0L)))))),
|
||||
mockFiltersBucket(
|
||||
"cat",
|
||||
86,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("ant", 20L),
|
||||
mockFiltersBucket("cat", 21L),
|
||||
mockFiltersBucket("dog", 22L),
|
||||
mockFiltersBucket("fox", 23L),
|
||||
mockFiltersBucket("_other_", 0L))))))))
|
||||
));
|
||||
Aggregations aggsStep3 = new Aggregations(Arrays.asList(
|
||||
mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket(
|
||||
"dog",
|
||||
126,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("ant", 30L),
|
||||
mockFiltersBucket("cat", 31L),
|
||||
mockFiltersBucket("dog", 32L),
|
||||
mockFiltersBucket("fox", 33L),
|
||||
mockFiltersBucket("_other_", 0L)))))),
|
||||
mockFiltersBucket(
|
||||
"fox",
|
||||
166,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("ant", 40L),
|
||||
mockFiltersBucket("cat", 41L),
|
||||
mockFiltersBucket("dog", 42L),
|
||||
mockFiltersBucket("fox", 43L),
|
||||
mockFiltersBucket("_other_", 0L))))))))
|
||||
));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(4, null);
|
||||
confusionMatrix.process(aggsStep1);
|
||||
confusionMatrix.process(aggsStep2);
|
||||
confusionMatrix.process(aggsStep3);
|
||||
|
||||
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||
Result result = confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
result.getConfusionMatrix(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new ActualClass("ant", 46, Arrays.asList(
|
||||
new PredictedClass("ant", 10L),
|
||||
new PredictedClass("cat", 11L),
|
||||
new PredictedClass("dog", 12L),
|
||||
new PredictedClass("fox", 13L)), 0),
|
||||
new ActualClass("cat", 86, Arrays.asList(
|
||||
new PredictedClass("ant", 20L),
|
||||
new PredictedClass("cat", 21L),
|
||||
new PredictedClass("dog", 22L),
|
||||
new PredictedClass("fox", 23L)), 0),
|
||||
new ActualClass("dog", 126, Arrays.asList(
|
||||
new PredictedClass("ant", 30L),
|
||||
new PredictedClass("cat", 31L),
|
||||
new PredictedClass("dog", 32L),
|
||||
new PredictedClass("fox", 33L)), 0),
|
||||
new ActualClass("fox", 166, Arrays.asList(
|
||||
new PredictedClass("ant", 40L),
|
||||
new PredictedClass("cat", 41L),
|
||||
new PredictedClass("dog", 42L),
|
||||
new PredictedClass("fox", 43L)), 0))));
|
||||
assertThat(result.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,22 +10,25 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||
|
||||
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||
|
||||
@Override
|
||||
protected Precision doParseInstance(XContentParser parser) throws IOException {
|
||||
return Precision.fromXContent(parser);
|
||||
|
@ -61,7 +64,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
|||
Precision precision = new Precision();
|
||||
precision.process(aggs);
|
||||
|
||||
assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(precision.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
|
||||
}
|
||||
|
||||
|
@ -111,7 +114,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
|||
Aggregations aggs =
|
||||
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
|
||||
Precision precision = new Precision();
|
||||
precision.aggs("foo", "bar");
|
||||
precision.aggs(EVALUATION_PARAMETERS, "foo", "bar");
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||
}
|
||||
|
|
|
@ -10,21 +10,24 @@ import org.elasticsearch.common.io.stream.Writeable;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||
|
||||
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||
|
||||
@Override
|
||||
protected Recall doParseInstance(XContentParser parser) throws IOException {
|
||||
return Recall.fromXContent(parser);
|
||||
|
@ -59,7 +62,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
|||
Recall recall = new Recall();
|
||||
recall.process(aggs);
|
||||
|
||||
assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(recall.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
|
||||
}
|
||||
|
||||
|
@ -110,7 +113,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
|||
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
|
||||
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
|
||||
Recall recall = new Recall();
|
||||
recall.aggs("foo", "bar");
|
||||
recall.aggs(EVALUATION_PARAMETERS, "foo", "bar");
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilders;
|
|||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -28,6 +29,8 @@ import static org.hamcrest.Matchers.greaterThan;
|
|||
|
||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
||||
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
|
@ -85,7 +88,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
|
||||
Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError()));
|
||||
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
|
||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilders;
|
|||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -28,6 +29,8 @@ import static org.hamcrest.Matchers.greaterThan;
|
|||
|
||||
public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> {
|
||||
|
||||
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
|
@ -98,7 +101,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
|
|||
|
||||
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
|
||||
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
|
||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||
}
|
||||
|
|
|
@ -5,18 +5,21 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.search.aggregations.MultiBucketConsumerService.TooManyBucketsException;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -28,6 +31,8 @@ import static org.hamcrest.Matchers.contains;
|
|||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
|
@ -49,6 +54,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
@After
|
||||
public void cleanup() {
|
||||
cleanUp();
|
||||
client().admin().cluster()
|
||||
.prepareUpdateSettings()
|
||||
.setTransientSettings(Settings.builder().putNull("search.max_buckets"))
|
||||
.get();
|
||||
}
|
||||
|
||||
public void testEvaluate_DefaultMetrics() {
|
||||
|
@ -208,7 +217,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||
}
|
||||
|
||||
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
|
||||
private void evaluateWithMulticlassConfusionMatrix() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
|
@ -271,6 +280,23 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
|
||||
evaluateWithMulticlassConfusionMatrix();
|
||||
|
||||
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get();
|
||||
evaluateWithMulticlassConfusionMatrix();
|
||||
|
||||
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get();
|
||||
evaluateWithMulticlassConfusionMatrix();
|
||||
|
||||
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get();
|
||||
ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix);
|
||||
|
||||
assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class)));
|
||||
TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause();
|
||||
assertThat(tmbe.getMaxBuckets(), equalTo(6));
|
||||
}
|
||||
|
||||
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.action.search.SearchRequest;
|
|||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.HandledTransportAction;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
|
@ -18,26 +19,41 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
|
||||
|
||||
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
|
||||
EvaluateDataFrameAction.Response> {
|
||||
|
||||
private final ThreadPool threadPool;
|
||||
private final Client client;
|
||||
private final AtomicReference<Integer> maxBuckets = new AtomicReference<>();
|
||||
|
||||
@Inject
|
||||
public TransportEvaluateDataFrameAction(TransportService transportService, ActionFilters actionFilters, ThreadPool threadPool,
|
||||
Client client) {
|
||||
public TransportEvaluateDataFrameAction(TransportService transportService,
|
||||
ActionFilters actionFilters,
|
||||
ThreadPool threadPool,
|
||||
Client client,
|
||||
ClusterService clusterService) {
|
||||
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
|
||||
this.threadPool = threadPool;
|
||||
this.client = client;
|
||||
this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings()));
|
||||
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets);
|
||||
}
|
||||
|
||||
private void setMaxBuckets(int maxBuckets) {
|
||||
this.maxBuckets.set(maxBuckets);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doExecute(Task task, EvaluateDataFrameAction.Request request,
|
||||
protected void doExecute(Task task,
|
||||
EvaluateDataFrameAction.Request request,
|
||||
ActionListener<EvaluateDataFrameAction.Response> listener) {
|
||||
ActionListener<List<Void>> resultsListener = ActionListener.wrap(
|
||||
unused -> {
|
||||
|
@ -48,7 +64,9 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
listener::onFailure
|
||||
);
|
||||
|
||||
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request);
|
||||
// Create an immutable collection of parameters to be used by evaluation metrics.
|
||||
EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get());
|
||||
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request);
|
||||
evaluationExecutor.execute(resultsListener);
|
||||
}
|
||||
|
||||
|
@ -68,12 +86,14 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> {
|
||||
|
||||
private final Client client;
|
||||
private final EvaluationParameters parameters;
|
||||
private final EvaluateDataFrameAction.Request request;
|
||||
private final Evaluation evaluation;
|
||||
|
||||
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) {
|
||||
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluationParameters parameters, EvaluateDataFrameAction.Request request) {
|
||||
super(threadPool.generic(), unused -> true, unused -> true);
|
||||
this.client = client;
|
||||
this.parameters = parameters;
|
||||
this.request = request;
|
||||
this.evaluation = request.getEvaluation();
|
||||
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
|
||||
|
@ -82,7 +102,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
|
||||
private TypedChainTaskExecutor.ChainTask<Void> nextTask() {
|
||||
return listener -> {
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery());
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery());
|
||||
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
|
||||
client.execute(
|
||||
SearchAction.INSTANCE,
|
||||
|
|
Loading…
Reference in New Issue