Perform evaluation in multiple steps when necessary (#53295) (#53409)

This commit is contained in:
Przemysław Witek 2020-03-11 15:36:38 +01:00 committed by GitHub
parent 562a9eff33
commit 8c4c19d310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 327 additions and 66 deletions

View File

@ -66,7 +66,7 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
* Builds the search required to collect data to compute the evaluation result * Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/ */
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder); Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery = BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery() QueryBuilders.boolQuery()
@ -78,7 +78,8 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) { for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics // Fetch aggregations requested by individual metrics
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField()); Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
metric.aggs(parameters, getActualField(), getPredictedField());
aggs.v1().forEach(searchSourceBuilder::aggregation); aggs.v1().forEach(searchSourceBuilder::aggregation);
aggs.v2().forEach(searchSourceBuilder::aggregation); aggs.v2().forEach(searchSourceBuilder::aggregation);
} }

View File

@ -28,11 +28,14 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
/** /**
* Builds the aggregation that collect required data to compute the metric * Builds the aggregation that collect required data to compute the metric
* @param parameters settings that may be needed by aggregations
* @param actualField the field that stores the actual value * @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value (class name or probability) * @param predictedField the field that stores the predicted value (class name or probability)
* @return the aggregations required to compute the metric * @return the aggregations required to compute the metric
*/ */
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField); Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField);
/** /**
* Processes given aggregations as a step towards computing result * Processes given aggregations as a step towards computing result

View File

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
/**
* Encapsulates parameters needed by evaluation.
*/
public class EvaluationParameters {
/**
* Maximum number of buckets allowed in any single search request.
*/
private final int maxBuckets;
public EvaluationParameters(int maxBuckets) {
this.maxBuckets = maxBuckets;
}
public int getMaxBuckets() {
return maxBuckets;
}
}

View File

@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -103,7 +104,9 @@ public class Accuracy implements EvaluationMetric {
} }
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}. // Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField); this.actualField.trySet(actualField);
List<AggregationBuilder> aggs = new ArrayList<>(); List<AggregationBuilder> aggs = new ArrayList<>();
@ -112,7 +115,8 @@ public class Accuracy implements EvaluationMetric {
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField))); aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
} }
if (result.get() == null) { if (result.get() == null) {
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField); Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
matrix.aggs(parameters, actualField, predictedField);
aggs.addAll(matrixAggs.v1()); aggs.addAll(matrixAggs.v1());
pipelineAggs.addAll(matrixAggs.v2()); pipelineAggs.addAll(matrixAggs.v2());
} }

View File

@ -29,6 +29,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality; import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -73,9 +74,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
} }
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
static final String STEP_1_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_cardinality_of_actual_class";
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
private static final String OTHER_BUCKET_KEY = "_other_"; private static final String OTHER_BUCKET_KEY = "_other_";
private static final String DEFAULT_AGG_NAME_PREFIX = ""; private static final String DEFAULT_AGG_NAME_PREFIX = "";
private static final int DEFAULT_SIZE = 10; private static final int DEFAULT_SIZE = 10;
@ -84,6 +85,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
private final int size; private final int size;
private final String aggNamePrefix; private final String aggNamePrefix;
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>(); private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
private final SetOnce<Long> actualClassesCardinality = new SetOnce<>();
/** Accumulates actual classes processed so far. It may take more than 1 call to #process method to fill this field completely. */
private final List<ActualClass> actualClasses = new ArrayList<>();
private final SetOnce<Result> result = new SetOnce<>(); private final SetOnce<Result> result = new SetOnce<>();
public MulticlassConfusionMatrix() { public MulticlassConfusionMatrix() {
@ -122,34 +126,45 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
} }
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
if (topActualClassNames.get() == null) { // This is step 1 String actualField,
String predictedField) {
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
return Tuple.tuple( return Tuple.tuple(
Arrays.asList( Arrays.asList(
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
.field(actualField) .field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)), .size(size),
AggregationBuilders.cardinality(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS))
.field(actualField)),
Collections.emptyList()); Collections.emptyList());
} }
if (result.get() == null) { // This is step 2 if (result.get() == null) { // These are steps 2, 3, 4 etc.
KeyedFilter[] keyedFiltersActual =
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new);
KeyedFilter[] keyedFiltersPredicted = KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.get().stream() topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new); .toArray(KeyedFilter[]::new);
return Tuple.tuple( // Knowing exactly how many buckets does each aggregation use, we can choose the size of the batch so that
Arrays.asList( // too_many_buckets_exception exception is not thrown.
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)) // The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch.
.field(actualField), // In such case, the exception will be thrown telling the user they should increase the value of "search.max_buckets".
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual) int actualClassesPerBatch = Math.max(parameters.getMaxBuckets() / (topActualClassNames.get().size() + 2), 1);
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted) KeyedFilter[] keyedFiltersActual =
.otherBucket(true) topActualClassNames.get().stream()
.otherBucketKey(OTHER_BUCKET_KEY))), .skip(actualClasses.size())
Collections.emptyList()); .limit(actualClassesPerBatch)
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new);
if (keyedFiltersActual.length > 0) {
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY))),
Collections.emptyList());
}
} }
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
} }
@ -160,10 +175,12 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)); Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList())); topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
} }
if (actualClassesCardinality.get() == null && aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)) != null) {
Cardinality cardinalityAgg = aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS));
actualClassesCardinality.set(cardinalityAgg.getValue());
}
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) { if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)); Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
for (Filters.Bucket bucket : filtersAgg.getBuckets()) { for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString(); String actualClass = bucket.getKeyAsString();
long actualClassDocCount = bucket.getDocCount(); long actualClassDocCount = bucket.getDocCount();
@ -182,7 +199,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount)); actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
} }
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0))); if (actualClasses.size() == topActualClassNames.get().size()) {
result.set(new Result(actualClasses, Math.max(actualClassesCardinality.get() - size, 0)));
}
} }
} }

View File

@ -30,6 +30,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -97,7 +98,9 @@ public class Precision implements EvaluationMetric {
} }
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}. // Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField); this.actualField.trySet(actualField);
if (topActualClassNames.get() == null) { // This is step 1 if (topActualClassNames.get() == null) { // This is step 1

View File

@ -26,6 +26,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -90,7 +91,9 @@ public class Recall implements EvaluationMetric {
} }
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}. // Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField); this.actualField.trySet(actualField);
if (result.get() != null) { if (result.get() != null) {

View File

@ -20,6 +20,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException; import java.io.IOException;
import java.text.MessageFormat; import java.text.MessageFormat;
@ -67,7 +68,9 @@ public class MeanSquaredError implements EvaluationMetric {
} }
@Override @Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
if (result != null) { if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
} }

View File

@ -22,6 +22,7 @@ import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBui
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException; import java.io.IOException;
import java.text.MessageFormat; import java.text.MessageFormat;
@ -72,7 +73,9 @@ public class RSquared implements EvaluationMetric {
} }
@Override @Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
if (result != null) { if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
} }

View File

@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -66,7 +67,9 @@ abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
} }
@Override @Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) { public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedProbabilityField) {
if (result != null) { if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
} }

View File

@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Percentiles; import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -127,7 +128,9 @@ public class AucRoc implements EvaluationMetric {
} }
@Override @Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) { public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedProbabilityField) {
if (result != null) { if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
} }

View File

@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.equalTo;
public class EvaluationParametersTests extends ESTestCase {
public void testConstructorAndGetters() {
EvaluationParameters params = new EvaluationParameters(17);
assertThat(params.getMaxBuckets(), equalTo(17));
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
@ -17,19 +18,21 @@ import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> { public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
@Override @Override
protected Accuracy doParseInstance(XContentParser parser) throws IOException { protected Accuracy doParseInstance(XContentParser parser) throws IOException {
return Accuracy.fromXContent(parser); return Accuracy.fromXContent(parser);
@ -62,6 +65,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L), 100L),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
mockFilters( mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
@ -79,13 +83,12 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
accuracy.process(aggs); accuracy.process(aggs);
assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty())); assertThat(accuracy.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
Result result = accuracy.getResult().get(); Result result = accuracy.getResult().get();
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
@ -106,6 +109,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L), 100L),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
mockFilters( mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
@ -123,11 +127,10 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
accuracy.aggs("foo", "bar"); accuracy.aggs(EVALUATION_PARAMETERS, "foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs)); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
} }

View File

@ -25,6 +25,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException; import java.io.IOException;
@ -43,6 +44,8 @@ import static org.mockito.Mockito.when;
public class ClassificationTests extends AbstractSerializingTestCase<Classification> { public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
@Override @Override
protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
@ -100,7 +103,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix())); Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
} }
@ -196,7 +199,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
} }
@Override @Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
} }

View File

@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
@ -23,18 +24,20 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.not;
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> { public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
@Override @Override
protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException { protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrix.fromXContent(parser); return MulticlassConfusionMatrix.fromXContent(parser);
@ -80,12 +83,12 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public void testAggs() { public void testAggs() {
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs("act", "pred"); Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred");
assertThat(aggs, isTuple(not(empty()), empty())); assertThat(aggs, isTuple(not(empty()), empty()));
assertThat(confusionMatrix.getResult(), isEmpty()); assertThat(confusionMatrix.getResult(), isEmpty());
} }
public void testEvaluate() { public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms( mockTerms(
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
@ -93,6 +96,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
0L), 0L),
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L),
mockFilters( mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
@ -109,13 +113,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))))
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L))); ));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs); confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
Result result = confusionMatrix.getResult().get(); Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat( assertThat(
@ -127,7 +131,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
assertThat(result.getOtherActualClassCount(), equalTo(0L)); assertThat(result.getOtherActualClassCount(), equalTo(0L));
} }
public void testEvaluate_OtherClassesCountGreaterThanZero() { public void testProcess_OtherClassesCountGreaterThanZero() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms( mockTerms(
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
@ -135,6 +139,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L), 100L),
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 5L),
mockFilters( mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
@ -151,13 +156,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))))
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L))); ));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs); confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
Result result = confusionMatrix.getResult().get(); Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat( assertThat(
@ -168,4 +173,106 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
assertThat(result.getOtherActualClassCount(), equalTo(3L)); assertThat(result.getOtherActualClassCount(), equalTo(3L));
} }
public void testProcess_MoreThanTwoStepsNeeded() {
Aggregations aggsStep1 = new Aggregations(Arrays.asList(
mockTerms(
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockTermsBucket("ant", new Aggregations(Arrays.asList())),
mockTermsBucket("cat", new Aggregations(Arrays.asList())),
mockTermsBucket("dog", new Aggregations(Arrays.asList())),
mockTermsBucket("fox", new Aggregations(Arrays.asList()))),
0L),
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L)
));
Aggregations aggsStep2 = new Aggregations(Arrays.asList(
mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockFiltersBucket(
"ant",
46,
new Aggregations(Arrays.asList(mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("ant", 10L),
mockFiltersBucket("cat", 11L),
mockFiltersBucket("dog", 12L),
mockFiltersBucket("fox", 13L),
mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
86,
new Aggregations(Arrays.asList(mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("ant", 20L),
mockFiltersBucket("cat", 21L),
mockFiltersBucket("dog", 22L),
mockFiltersBucket("fox", 23L),
mockFiltersBucket("_other_", 0L))))))))
));
Aggregations aggsStep3 = new Aggregations(Arrays.asList(
mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockFiltersBucket(
"dog",
126,
new Aggregations(Arrays.asList(mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("ant", 30L),
mockFiltersBucket("cat", 31L),
mockFiltersBucket("dog", 32L),
mockFiltersBucket("fox", 33L),
mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"fox",
166,
new Aggregations(Arrays.asList(mockFilters(
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("ant", 40L),
mockFiltersBucket("cat", 41L),
mockFiltersBucket("dog", 42L),
mockFiltersBucket("fox", 43L),
mockFiltersBucket("_other_", 0L))))))))
));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(4, null);
confusionMatrix.process(aggsStep1);
confusionMatrix.process(aggsStep2);
confusionMatrix.process(aggsStep3);
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(
result.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass("ant", 46, Arrays.asList(
new PredictedClass("ant", 10L),
new PredictedClass("cat", 11L),
new PredictedClass("dog", 12L),
new PredictedClass("fox", 13L)), 0),
new ActualClass("cat", 86, Arrays.asList(
new PredictedClass("ant", 20L),
new PredictedClass("cat", 21L),
new PredictedClass("dog", 22L),
new PredictedClass("fox", 23L)), 0),
new ActualClass("dog", 126, Arrays.asList(
new PredictedClass("ant", 30L),
new PredictedClass("cat", 31L),
new PredictedClass("dog", 32L),
new PredictedClass("fox", 33L)), 0),
new ActualClass("fox", 166, Arrays.asList(
new PredictedClass("ant", 40L),
new PredictedClass("cat", 41L),
new PredictedClass("dog", 42L),
new PredictedClass("fox", 43L)), 0))));
assertThat(result.getOtherActualClassCount(), equalTo(0L));
}
} }

View File

@ -10,22 +10,25 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class PrecisionTests extends AbstractSerializingTestCase<Precision> { public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
@Override @Override
protected Precision doParseInstance(XContentParser parser) throws IOException { protected Precision doParseInstance(XContentParser parser) throws IOException {
return Precision.fromXContent(parser); return Precision.fromXContent(parser);
@ -61,7 +64,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
Precision precision = new Precision(); Precision precision = new Precision();
precision.process(aggs); precision.process(aggs);
assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty())); assertThat(precision.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123))); assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
} }
@ -111,7 +114,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
Aggregations aggs = Aggregations aggs =
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1))); new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
Precision precision = new Precision(); Precision precision = new Precision();
precision.aggs("foo", "bar"); precision.aggs(EVALUATION_PARAMETERS, "foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs)); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
} }

View File

@ -10,21 +10,24 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class RecallTests extends AbstractSerializingTestCase<Recall> { public class RecallTests extends AbstractSerializingTestCase<Recall> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
@Override @Override
protected Recall doParseInstance(XContentParser parser) throws IOException { protected Recall doParseInstance(XContentParser parser) throws IOException {
return Recall.fromXContent(parser); return Recall.fromXContent(parser);
@ -59,7 +62,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
Recall recall = new Recall(); Recall recall = new Recall();
recall.process(aggs); recall.process(aggs);
assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty())); assertThat(recall.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123))); assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
} }
@ -110,7 +113,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1), mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123))); mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
Recall recall = new Recall(); Recall recall = new Recall();
recall.aggs("foo", "bar"); recall.aggs(EVALUATION_PARAMETERS, "foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs)); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
} }

View File

@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException; import java.io.IOException;
@ -28,6 +29,8 @@ import static org.hamcrest.Matchers.greaterThan;
public class RegressionTests extends AbstractSerializingTestCase<Regression> { public class RegressionTests extends AbstractSerializingTestCase<Regression> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
@Override @Override
protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
@ -85,7 +88,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError())); Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError()));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
} }

View File

@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException; import java.io.IOException;
@ -28,6 +29,8 @@ import static org.hamcrest.Matchers.greaterThan;
public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> { public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
@Override @Override
protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables()); return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
@ -98,7 +101,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7)))); BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
} }

View File

@ -5,18 +5,21 @@
*/ */
package org.elasticsearch.xpack.ml.integration; package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.aggregations.MultiBucketConsumerService.TooManyBucketsException;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -28,6 +31,8 @@ import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@ -49,6 +54,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
@After @After
public void cleanup() { public void cleanup() {
cleanUp(); cleanUp();
client().admin().cluster()
.prepareUpdateSettings()
.setTransientSettings(Settings.builder().putNull("search.max_buckets"))
.get();
} }
public void testEvaluate_DefaultMetrics() { public void testEvaluate_DefaultMetrics() {
@ -208,7 +217,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high")); assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
} }
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() { private void evaluateWithMulticlassConfusionMatrix() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse = EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame( evaluateDataFrame(
ANIMALS_DATA_INDEX, ANIMALS_DATA_INDEX,
@ -271,6 +280,23 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
} }
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
evaluateWithMulticlassConfusionMatrix();
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get();
evaluateWithMulticlassConfusionMatrix();
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get();
evaluateWithMulticlassConfusionMatrix();
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get();
ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix);
assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class)));
TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause();
assertThat(tmbe.getMaxBuckets(), equalTo(6));
}
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() { public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse = EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame( evaluateDataFrame(

View File

@ -11,6 +11,7 @@ import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
@ -18,26 +19,41 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request, public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
EvaluateDataFrameAction.Response> { EvaluateDataFrameAction.Response> {
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final Client client; private final Client client;
private final AtomicReference<Integer> maxBuckets = new AtomicReference<>();
@Inject @Inject
public TransportEvaluateDataFrameAction(TransportService transportService, ActionFilters actionFilters, ThreadPool threadPool, public TransportEvaluateDataFrameAction(TransportService transportService,
Client client) { ActionFilters actionFilters,
ThreadPool threadPool,
Client client,
ClusterService clusterService) {
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new); super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
this.threadPool = threadPool; this.threadPool = threadPool;
this.client = client; this.client = client;
this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets);
}
private void setMaxBuckets(int maxBuckets) {
this.maxBuckets.set(maxBuckets);
} }
@Override @Override
protected void doExecute(Task task, EvaluateDataFrameAction.Request request, protected void doExecute(Task task,
EvaluateDataFrameAction.Request request,
ActionListener<EvaluateDataFrameAction.Response> listener) { ActionListener<EvaluateDataFrameAction.Response> listener) {
ActionListener<List<Void>> resultsListener = ActionListener.wrap( ActionListener<List<Void>> resultsListener = ActionListener.wrap(
unused -> { unused -> {
@ -48,7 +64,9 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
listener::onFailure listener::onFailure
); );
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request); // Create an immutable collection of parameters to be used by evaluation metrics.
EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get());
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request);
evaluationExecutor.execute(resultsListener); evaluationExecutor.execute(resultsListener);
} }
@ -68,12 +86,14 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> { private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> {
private final Client client; private final Client client;
private final EvaluationParameters parameters;
private final EvaluateDataFrameAction.Request request; private final EvaluateDataFrameAction.Request request;
private final Evaluation evaluation; private final Evaluation evaluation;
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) { EvaluationExecutor(ThreadPool threadPool, Client client, EvaluationParameters parameters, EvaluateDataFrameAction.Request request) {
super(threadPool.generic(), unused -> true, unused -> true); super(threadPool.generic(), unused -> true, unused -> true);
this.client = client; this.client = client;
this.parameters = parameters;
this.request = request; this.request = request;
this.evaluation = request.getEvaluation(); this.evaluation = request.getEvaluation();
// Add one task only. Other tasks will be added as needed by the nextTask method itself. // Add one task only. Other tasks will be added as needed by the nextTask method itself.
@ -82,7 +102,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
private TypedChainTaskExecutor.ChainTask<Void> nextTask() { private TypedChainTaskExecutor.ChainTask<Void> nextTask() {
return listener -> { return listener -> {
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery()); SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery());
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder); SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
client.execute( client.execute(
SearchAction.INSTANCE, SearchAction.INSTANCE,