parent
562a9eff33
commit
8c4c19d310
|
@ -66,7 +66,7 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
||||||
* Builds the search required to collect data to compute the evaluation result
|
* Builds the search required to collect data to compute the evaluation result
|
||||||
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
|
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
|
||||||
*/
|
*/
|
||||||
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
|
||||||
Objects.requireNonNull(userProvidedQueryBuilder);
|
Objects.requireNonNull(userProvidedQueryBuilder);
|
||||||
BoolQueryBuilder boolQuery =
|
BoolQueryBuilder boolQuery =
|
||||||
QueryBuilders.boolQuery()
|
QueryBuilders.boolQuery()
|
||||||
|
@ -78,7 +78,8 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
||||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||||
for (EvaluationMetric metric : getMetrics()) {
|
for (EvaluationMetric metric : getMetrics()) {
|
||||||
// Fetch aggregations requested by individual metrics
|
// Fetch aggregations requested by individual metrics
|
||||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField());
|
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
|
||||||
|
metric.aggs(parameters, getActualField(), getPredictedField());
|
||||||
aggs.v1().forEach(searchSourceBuilder::aggregation);
|
aggs.v1().forEach(searchSourceBuilder::aggregation);
|
||||||
aggs.v2().forEach(searchSourceBuilder::aggregation);
|
aggs.v2().forEach(searchSourceBuilder::aggregation);
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,11 +28,14 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the aggregations that collect the data required to compute the metric
|
* Builds the aggregations that collect the data required to compute the metric
|
||||||
|
* @param parameters settings that may be needed by aggregations
|
||||||
* @param actualField the field that stores the actual value
|
* @param actualField the field that stores the actual value
|
||||||
* @param predictedField the field that stores the predicted value (class name or probability)
|
* @param predictedField the field that stores the predicted value (class name or probability)
|
||||||
* @return the aggregations required to compute the metric
|
* @return the aggregations required to compute the metric
|
||||||
*/
|
*/
|
||||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField);
|
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedField);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Processes given aggregations as a step towards computing result
|
* Processes given aggregations as a step towards computing result
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Encapsulates parameters needed by evaluation.
|
||||||
|
*/
|
||||||
|
public class EvaluationParameters {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maximum number of buckets allowed in any single search request.
|
||||||
|
*/
|
||||||
|
private final int maxBuckets;
|
||||||
|
|
||||||
|
public EvaluationParameters(int maxBuckets) {
|
||||||
|
this.maxBuckets = maxBuckets;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMaxBuckets() {
|
||||||
|
return maxBuckets;
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -103,7 +104,9 @@ public class Accuracy implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedField) {
|
||||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||||
this.actualField.trySet(actualField);
|
this.actualField.trySet(actualField);
|
||||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||||
|
@ -112,7 +115,8 @@ public class Accuracy implements EvaluationMetric {
|
||||||
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
|
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
|
||||||
}
|
}
|
||||||
if (result.get() == null) {
|
if (result.get() == null) {
|
||||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
|
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
|
||||||
|
matrix.aggs(parameters, actualField, predictedField);
|
||||||
aggs.addAll(matrixAggs.v1());
|
aggs.addAll(matrixAggs.v1());
|
||||||
pipelineAggs.addAll(matrixAggs.v2());
|
pipelineAggs.addAll(matrixAggs.v2());
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -73,9 +74,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
||||||
|
static final String STEP_1_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_cardinality_of_actual_class";
|
||||||
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
||||||
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
||||||
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
|
|
||||||
private static final String OTHER_BUCKET_KEY = "_other_";
|
private static final String OTHER_BUCKET_KEY = "_other_";
|
||||||
private static final String DEFAULT_AGG_NAME_PREFIX = "";
|
private static final String DEFAULT_AGG_NAME_PREFIX = "";
|
||||||
private static final int DEFAULT_SIZE = 10;
|
private static final int DEFAULT_SIZE = 10;
|
||||||
|
@ -84,6 +85,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
private final int size;
|
private final int size;
|
||||||
private final String aggNamePrefix;
|
private final String aggNamePrefix;
|
||||||
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
|
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
|
||||||
|
private final SetOnce<Long> actualClassesCardinality = new SetOnce<>();
|
||||||
|
/** Accumulates actual classes processed so far. It may take more than one call to the {@code process} method to fill this field completely. */
|
||||||
|
private final List<ActualClass> actualClasses = new ArrayList<>();
|
||||||
private final SetOnce<Result> result = new SetOnce<>();
|
private final SetOnce<Result> result = new SetOnce<>();
|
||||||
|
|
||||||
public MulticlassConfusionMatrix() {
|
public MulticlassConfusionMatrix() {
|
||||||
|
@ -122,34 +126,45 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
if (topActualClassNames.get() == null) { // This is step 1
|
String actualField,
|
||||||
|
String predictedField) {
|
||||||
|
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
|
||||||
return Tuple.tuple(
|
return Tuple.tuple(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
|
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
|
||||||
.field(actualField)
|
.field(actualField)
|
||||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||||
.size(size)),
|
.size(size),
|
||||||
|
AggregationBuilders.cardinality(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS))
|
||||||
|
.field(actualField)),
|
||||||
Collections.emptyList());
|
Collections.emptyList());
|
||||||
}
|
}
|
||||||
if (result.get() == null) { // This is step 2
|
if (result.get() == null) { // These are steps 2, 3, 4 etc.
|
||||||
KeyedFilter[] keyedFiltersActual =
|
|
||||||
topActualClassNames.get().stream()
|
|
||||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
|
||||||
.toArray(KeyedFilter[]::new);
|
|
||||||
KeyedFilter[] keyedFiltersPredicted =
|
KeyedFilter[] keyedFiltersPredicted =
|
||||||
topActualClassNames.get().stream()
|
topActualClassNames.get().stream()
|
||||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||||
.toArray(KeyedFilter[]::new);
|
.toArray(KeyedFilter[]::new);
|
||||||
return Tuple.tuple(
|
// Knowing exactly how many buckets each aggregation uses, we can choose the size of the batch so that
|
||||||
Arrays.asList(
|
// too_many_buckets_exception is not thrown.
|
||||||
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
|
// The only exception is when "search.max_buckets" is set far too low to even have 1 actual class in the batch.
|
||||||
.field(actualField),
|
// In such a case, the exception is thrown, telling the user to increase the value of "search.max_buckets".
|
||||||
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
|
int actualClassesPerBatch = Math.max(parameters.getMaxBuckets() / (topActualClassNames.get().size() + 2), 1);
|
||||||
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
|
KeyedFilter[] keyedFiltersActual =
|
||||||
.otherBucket(true)
|
topActualClassNames.get().stream()
|
||||||
.otherBucketKey(OTHER_BUCKET_KEY))),
|
.skip(actualClasses.size())
|
||||||
Collections.emptyList());
|
.limit(actualClassesPerBatch)
|
||||||
|
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
||||||
|
.toArray(KeyedFilter[]::new);
|
||||||
|
if (keyedFiltersActual.length > 0) {
|
||||||
|
return Tuple.tuple(
|
||||||
|
Arrays.asList(
|
||||||
|
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
|
||||||
|
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
|
||||||
|
.otherBucket(true)
|
||||||
|
.otherBucketKey(OTHER_BUCKET_KEY))),
|
||||||
|
Collections.emptyList());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
@ -160,10 +175,12 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
|
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
|
||||||
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
|
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
|
if (actualClassesCardinality.get() == null && aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS)) != null) {
|
||||||
|
Cardinality cardinalityAgg = aggs.get(aggName(STEP_1_CARDINALITY_OF_ACTUAL_CLASS));
|
||||||
|
actualClassesCardinality.set(cardinalityAgg.getValue());
|
||||||
|
}
|
||||||
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
|
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
|
||||||
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
|
|
||||||
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
|
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
|
||||||
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
|
|
||||||
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
|
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
|
||||||
String actualClass = bucket.getKeyAsString();
|
String actualClass = bucket.getKeyAsString();
|
||||||
long actualClassDocCount = bucket.getDocCount();
|
long actualClassDocCount = bucket.getDocCount();
|
||||||
|
@ -182,7 +199,9 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
|
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
|
||||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
|
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
|
||||||
}
|
}
|
||||||
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
|
if (actualClasses.size() == topActualClassNames.get().size()) {
|
||||||
|
result.set(new Result(actualClasses, Math.max(actualClassesCardinality.get() - size, 0)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -97,7 +98,9 @@ public class Precision implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedField) {
|
||||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||||
this.actualField.trySet(actualField);
|
this.actualField.trySet(actualField);
|
||||||
if (topActualClassNames.get() == null) { // This is step 1
|
if (topActualClassNames.get() == null) { // This is step 1
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -90,7 +91,9 @@ public class Recall implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedField) {
|
||||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||||
this.actualField.trySet(actualField);
|
this.actualField.trySet(actualField);
|
||||||
if (result.get() != null) {
|
if (result.get() != null) {
|
||||||
|
|
|
@ -20,6 +20,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.text.MessageFormat;
|
import java.text.MessageFormat;
|
||||||
|
@ -67,7 +68,9 @@ public class MeanSquaredError implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedField) {
|
||||||
if (result != null) {
|
if (result != null) {
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBui
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.text.MessageFormat;
|
import java.text.MessageFormat;
|
||||||
|
@ -72,7 +73,9 @@ public class RSquared implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedField) {
|
||||||
if (result != null) {
|
if (result != null) {
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -66,7 +67,9 @@ abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
|
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedProbabilityField) {
|
||||||
if (result != null) {
|
if (result != null) {
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
||||||
import org.elasticsearch.search.aggregations.metrics.Percentiles;
|
import org.elasticsearch.search.aggregations.metrics.Percentiles;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -127,7 +128,9 @@ public class AucRoc implements EvaluationMetric {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
|
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedProbabilityField) {
|
||||||
if (result != null) {
|
if (result != null) {
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||||
|
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
|
public class EvaluationParametersTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testConstructorAndGetters() {
|
||||||
|
EvaluationParameters params = new EvaluationParameters(17);
|
||||||
|
assertThat(params.getMaxBuckets(), equalTo(17));
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||||
|
|
||||||
|
@ -17,19 +18,21 @@ import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
|
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
|
|
||||||
|
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Accuracy doParseInstance(XContentParser parser) throws IOException {
|
protected Accuracy doParseInstance(XContentParser parser) throws IOException {
|
||||||
return Accuracy.fromXContent(parser);
|
return Accuracy.fromXContent(parser);
|
||||||
|
@ -62,6 +65,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
100L),
|
100L),
|
||||||
|
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
|
||||||
mockFilters(
|
mockFilters(
|
||||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -79,13 +83,12 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
|
|
||||||
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||||
|
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
accuracy.process(aggs);
|
accuracy.process(aggs);
|
||||||
|
|
||||||
assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
|
assertThat(accuracy.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||||
|
|
||||||
Result result = accuracy.getResult().get();
|
Result result = accuracy.getResult().get();
|
||||||
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||||
|
@ -106,6 +109,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
100L),
|
100L),
|
||||||
|
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
|
||||||
mockFilters(
|
mockFilters(
|
||||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -123,11 +127,10 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
|
|
||||||
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||||
|
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
accuracy.aggs("foo", "bar");
|
accuracy.aggs(EVALUATION_PARAMETERS, "foo", "bar");
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
|
||||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -43,6 +44,8 @@ import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||||
|
|
||||||
|
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||||
|
@ -100,7 +103,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||||
|
|
||||||
Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
|
Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
|
||||||
|
|
||||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
|
||||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||||
}
|
}
|
||||||
|
@ -196,7 +199,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||||
|
String actualField,
|
||||||
|
String predictedField) {
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
|
||||||
|
@ -23,18 +24,20 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||||
|
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.not;
|
import static org.hamcrest.Matchers.not;
|
||||||
|
|
||||||
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
|
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
|
||||||
|
|
||||||
|
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
|
protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
|
||||||
return MulticlassConfusionMatrix.fromXContent(parser);
|
return MulticlassConfusionMatrix.fromXContent(parser);
|
||||||
|
@ -80,12 +83,12 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
|
|
||||||
public void testAggs() {
|
public void testAggs() {
|
||||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
|
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
|
||||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs("act", "pred");
|
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred");
|
||||||
assertThat(aggs, isTuple(not(empty()), empty()));
|
assertThat(aggs, isTuple(not(empty()), empty()));
|
||||||
assertThat(confusionMatrix.getResult(), isEmpty());
|
assertThat(confusionMatrix.getResult(), isEmpty());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testProcess() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockTerms(
|
mockTerms(
|
||||||
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
|
@ -93,6 +96,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
0L),
|
0L),
|
||||||
|
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L),
|
||||||
mockFilters(
|
mockFilters(
|
||||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -109,13 +113,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
new Aggregations(Arrays.asList(mockFilters(
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))))
|
||||||
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
|
));
|
||||||
|
|
||||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||||
confusionMatrix.process(aggs);
|
confusionMatrix.process(aggs);
|
||||||
|
|
||||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||||
Result result = confusionMatrix.getResult().get();
|
Result result = confusionMatrix.getResult().get();
|
||||||
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||||
assertThat(
|
assertThat(
|
||||||
|
@ -127,7 +131,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
assertThat(result.getOtherActualClassCount(), equalTo(0L));
|
assertThat(result.getOtherActualClassCount(), equalTo(0L));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate_OtherClassesCountGreaterThanZero() {
|
public void testProcess_OtherClassesCountGreaterThanZero() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockTerms(
|
mockTerms(
|
||||||
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
|
@ -135,6 +139,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
100L),
|
100L),
|
||||||
|
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 5L),
|
||||||
mockFilters(
|
mockFilters(
|
||||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -151,13 +156,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
new Aggregations(Arrays.asList(mockFilters(
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))))
|
||||||
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
|
));
|
||||||
|
|
||||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||||
confusionMatrix.process(aggs);
|
confusionMatrix.process(aggs);
|
||||||
|
|
||||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||||
Result result = confusionMatrix.getResult().get();
|
Result result = confusionMatrix.getResult().get();
|
||||||
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||||
assertThat(
|
assertThat(
|
||||||
|
@ -168,4 +173,106 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
|
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
|
||||||
assertThat(result.getOtherActualClassCount(), equalTo(3L));
|
assertThat(result.getOtherActualClassCount(), equalTo(3L));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testProcess_MoreThanTwoStepsNeeded() {
|
||||||
|
Aggregations aggsStep1 = new Aggregations(Arrays.asList(
|
||||||
|
mockTerms(
|
||||||
|
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockTermsBucket("ant", new Aggregations(Arrays.asList())),
|
||||||
|
mockTermsBucket("cat", new Aggregations(Arrays.asList())),
|
||||||
|
mockTermsBucket("dog", new Aggregations(Arrays.asList())),
|
||||||
|
mockTermsBucket("fox", new Aggregations(Arrays.asList()))),
|
||||||
|
0L),
|
||||||
|
mockCardinality(MulticlassConfusionMatrix.STEP_1_CARDINALITY_OF_ACTUAL_CLASS, 2L)
|
||||||
|
));
|
||||||
|
Aggregations aggsStep2 = new Aggregations(Arrays.asList(
|
||||||
|
mockFilters(
|
||||||
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket(
|
||||||
|
"ant",
|
||||||
|
46,
|
||||||
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("ant", 10L),
|
||||||
|
mockFiltersBucket("cat", 11L),
|
||||||
|
mockFiltersBucket("dog", 12L),
|
||||||
|
mockFiltersBucket("fox", 13L),
|
||||||
|
mockFiltersBucket("_other_", 0L)))))),
|
||||||
|
mockFiltersBucket(
|
||||||
|
"cat",
|
||||||
|
86,
|
||||||
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("ant", 20L),
|
||||||
|
mockFiltersBucket("cat", 21L),
|
||||||
|
mockFiltersBucket("dog", 22L),
|
||||||
|
mockFiltersBucket("fox", 23L),
|
||||||
|
mockFiltersBucket("_other_", 0L))))))))
|
||||||
|
));
|
||||||
|
Aggregations aggsStep3 = new Aggregations(Arrays.asList(
|
||||||
|
mockFilters(
|
||||||
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket(
|
||||||
|
"dog",
|
||||||
|
126,
|
||||||
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("ant", 30L),
|
||||||
|
mockFiltersBucket("cat", 31L),
|
||||||
|
mockFiltersBucket("dog", 32L),
|
||||||
|
mockFiltersBucket("fox", 33L),
|
||||||
|
mockFiltersBucket("_other_", 0L)))))),
|
||||||
|
mockFiltersBucket(
|
||||||
|
"fox",
|
||||||
|
166,
|
||||||
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("ant", 40L),
|
||||||
|
mockFiltersBucket("cat", 41L),
|
||||||
|
mockFiltersBucket("dog", 42L),
|
||||||
|
mockFiltersBucket("fox", 43L),
|
||||||
|
mockFiltersBucket("_other_", 0L))))))))
|
||||||
|
));
|
||||||
|
|
||||||
|
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(4, null);
|
||||||
|
confusionMatrix.process(aggsStep1);
|
||||||
|
confusionMatrix.process(aggsStep2);
|
||||||
|
confusionMatrix.process(aggsStep3);
|
||||||
|
|
||||||
|
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||||
|
Result result = confusionMatrix.getResult().get();
|
||||||
|
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||||
|
assertThat(
|
||||||
|
result.getConfusionMatrix(),
|
||||||
|
equalTo(
|
||||||
|
Arrays.asList(
|
||||||
|
new ActualClass("ant", 46, Arrays.asList(
|
||||||
|
new PredictedClass("ant", 10L),
|
||||||
|
new PredictedClass("cat", 11L),
|
||||||
|
new PredictedClass("dog", 12L),
|
||||||
|
new PredictedClass("fox", 13L)), 0),
|
||||||
|
new ActualClass("cat", 86, Arrays.asList(
|
||||||
|
new PredictedClass("ant", 20L),
|
||||||
|
new PredictedClass("cat", 21L),
|
||||||
|
new PredictedClass("dog", 22L),
|
||||||
|
new PredictedClass("fox", 23L)), 0),
|
||||||
|
new ActualClass("dog", 126, Arrays.asList(
|
||||||
|
new PredictedClass("ant", 30L),
|
||||||
|
new PredictedClass("cat", 31L),
|
||||||
|
new PredictedClass("dog", 32L),
|
||||||
|
new PredictedClass("fox", 33L)), 0),
|
||||||
|
new ActualClass("fox", 166, Arrays.asList(
|
||||||
|
new PredictedClass("ant", 40L),
|
||||||
|
new PredictedClass("cat", 41L),
|
||||||
|
new PredictedClass("dog", 42L),
|
||||||
|
new PredictedClass("fox", 43L)), 0))));
|
||||||
|
assertThat(result.getOtherActualClassCount(), equalTo(0L));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,22 +10,25 @@ import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||||
|
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||||
|
|
||||||
|
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Precision doParseInstance(XContentParser parser) throws IOException {
|
protected Precision doParseInstance(XContentParser parser) throws IOException {
|
||||||
return Precision.fromXContent(parser);
|
return Precision.fromXContent(parser);
|
||||||
|
@ -61,7 +64,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||||
Precision precision = new Precision();
|
Precision precision = new Precision();
|
||||||
precision.process(aggs);
|
precision.process(aggs);
|
||||||
|
|
||||||
assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty()));
|
assertThat(precision.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||||
assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
|
assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +114,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||||
Aggregations aggs =
|
Aggregations aggs =
|
||||||
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
|
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
|
||||||
Precision precision = new Precision();
|
Precision precision = new Precision();
|
||||||
precision.aggs("foo", "bar");
|
precision.aggs(EVALUATION_PARAMETERS, "foo", "bar");
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
|
||||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,21 +10,24 @@ import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||||
|
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||||
import static org.elasticsearch.test.hamcrest.TupleMatchers.isTuple;
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||||
|
|
||||||
|
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Recall doParseInstance(XContentParser parser) throws IOException {
|
protected Recall doParseInstance(XContentParser parser) throws IOException {
|
||||||
return Recall.fromXContent(parser);
|
return Recall.fromXContent(parser);
|
||||||
|
@ -59,7 +62,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||||
Recall recall = new Recall();
|
Recall recall = new Recall();
|
||||||
recall.process(aggs);
|
recall.process(aggs);
|
||||||
|
|
||||||
assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty()));
|
assertThat(recall.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
|
||||||
assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
|
assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +113,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||||
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
|
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
|
||||||
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
|
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
|
||||||
Recall recall = new Recall();
|
Recall recall = new Recall();
|
||||||
recall.aggs("foo", "bar");
|
recall.aggs(EVALUATION_PARAMETERS, "foo", "bar");
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
|
||||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilders;
|
||||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -28,6 +29,8 @@ import static org.hamcrest.Matchers.greaterThan;
|
||||||
|
|
||||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||||
|
|
||||||
|
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||||
|
@ -85,7 +88,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||||
|
|
||||||
Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError()));
|
Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError()));
|
||||||
|
|
||||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
|
||||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilders;
|
||||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -28,6 +29,8 @@ import static org.hamcrest.Matchers.greaterThan;
|
||||||
|
|
||||||
public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> {
|
public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> {
|
||||||
|
|
||||||
|
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||||
|
@ -98,7 +101,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
|
||||||
|
|
||||||
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
|
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
|
||||||
|
|
||||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
|
||||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,18 +5,21 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.ml.integration;
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||||
import org.elasticsearch.action.bulk.BulkResponse;
|
import org.elasticsearch.action.bulk.BulkResponse;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.search.aggregations.MultiBucketConsumerService.TooManyBucketsException;
|
||||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
|
@ -28,6 +31,8 @@ import static org.hamcrest.Matchers.contains;
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.hasSize;
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
|
@ -49,6 +54,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
// Teardown hook run after each test (JUnit @After).
@After
|
@After
|
||||||
public void cleanup() {
|
public void cleanup() {
|
||||||
cleanUp();
|
cleanUp();
|
||||||
|
// Null out the transient "search.max_buckets" cluster setting so that an override applied by a test
// in this class is not left in place for the tests that follow.
client().admin().cluster()
|
||||||
|
.prepareUpdateSettings()
|
||||||
|
.setTransientSettings(Settings.builder().putNull("search.max_buckets"))
|
||||||
|
.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate_DefaultMetrics() {
|
public void testEvaluate_DefaultMetrics() {
|
||||||
|
@ -208,7 +217,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
|
private void evaluateWithMulticlassConfusionMatrix() {
|
||||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||||
evaluateDataFrame(
|
evaluateDataFrame(
|
||||||
ANIMALS_DATA_INDEX,
|
ANIMALS_DATA_INDEX,
|
||||||
|
@ -271,6 +280,23 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs the shared multiclass-confusion-matrix evaluation repeatedly while lowering the transient
// "search.max_buckets" cluster setting, checking both the succeeding and the failing side of the limit.
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
|
||||||
|
// Baseline run, before this test applies any "search.max_buckets" override.
evaluateWithMulticlassConfusionMatrix();
|
||||||
|
|
||||||
|
// With the limit lowered to 20, and then to 7, the evaluation is still expected to succeed.
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 20)).get();
|
||||||
|
evaluateWithMulticlassConfusionMatrix();
|
||||||
|
|
||||||
|
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 7)).get();
|
||||||
|
evaluateWithMulticlassConfusionMatrix();
|
||||||
|
|
||||||
|
// With the limit at 6 the evaluation is expected to throw.
client().admin().cluster().prepareUpdateSettings().setTransientSettings(Settings.builder().put("search.max_buckets", 6)).get();
|
||||||
|
ElasticsearchException e = expectThrows(ElasticsearchException.class, this::evaluateWithMulticlassConfusionMatrix);
|
||||||
|
|
||||||
|
// The failure must be caused by a TooManyBucketsException that reports the configured limit (6).
assertThat(e.getCause(), is(instanceOf(TooManyBucketsException.class)));
|
||||||
|
TooManyBucketsException tmbe = (TooManyBucketsException) e.getCause();
|
||||||
|
assertThat(tmbe.getMaxBuckets(), equalTo(6));
|
||||||
|
}
|
||||||
|
|
||||||
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
|
public void testEvaluate_ConfusionMatrixMetricWithUserProvidedSize() {
|
||||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||||
evaluateDataFrame(
|
evaluateDataFrame(
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.action.search.SearchRequest;
|
||||||
import org.elasticsearch.action.support.ActionFilters;
|
import org.elasticsearch.action.support.ActionFilters;
|
||||||
import org.elasticsearch.action.support.HandledTransportAction;
|
import org.elasticsearch.action.support.HandledTransportAction;
|
||||||
import org.elasticsearch.client.Client;
|
import org.elasticsearch.client.Client;
|
||||||
|
import org.elasticsearch.cluster.service.ClusterService;
|
||||||
import org.elasticsearch.common.inject.Inject;
|
import org.elasticsearch.common.inject.Inject;
|
||||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
import org.elasticsearch.tasks.Task;
|
import org.elasticsearch.tasks.Task;
|
||||||
|
@ -18,26 +19,41 @@ import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.transport.TransportService;
|
import org.elasticsearch.transport.TransportService;
|
||||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
import static org.elasticsearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING;
|
||||||
|
|
||||||
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
|
public class TransportEvaluateDataFrameAction extends HandledTransportAction<EvaluateDataFrameAction.Request,
|
||||||
EvaluateDataFrameAction.Response> {
|
EvaluateDataFrameAction.Response> {
|
||||||
|
|
||||||
private final ThreadPool threadPool;
|
private final ThreadPool threadPool;
|
||||||
private final Client client;
|
private final Client client;
|
||||||
|
private final AtomicReference<Integer> maxBuckets = new AtomicReference<>();
|
||||||
|
|
||||||
// Injected constructor: registers the action under EvaluateDataFrameAction.NAME and keeps the
// thread pool and client that doExecute later hands to the evaluation executor.
@Inject
|
@Inject
|
||||||
public TransportEvaluateDataFrameAction(TransportService transportService, ActionFilters actionFilters, ThreadPool threadPool,
|
public TransportEvaluateDataFrameAction(TransportService transportService,
|
||||||
Client client) {
|
ActionFilters actionFilters,
|
||||||
|
ThreadPool threadPool,
|
||||||
|
Client client,
|
||||||
|
ClusterService clusterService) {
|
||||||
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
|
super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new);
|
||||||
this.threadPool = threadPool;
|
this.threadPool = threadPool;
|
||||||
this.client = client;
|
this.client = client;
|
||||||
|
// Seed maxBuckets from the settings visible at construction time, then register setMaxBuckets
// so that later updates to MAX_BUCKET_SETTING replace the stored value.
this.maxBuckets.set(MAX_BUCKET_SETTING.get(clusterService.getSettings()));
|
||||||
|
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BUCKET_SETTING, this::setMaxBuckets);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Settings-update callback registered in the constructor; stores the new MAX_BUCKET_SETTING value.
private void setMaxBuckets(int maxBuckets) {
|
||||||
|
this.maxBuckets.set(maxBuckets);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doExecute(Task task, EvaluateDataFrameAction.Request request,
|
protected void doExecute(Task task,
|
||||||
|
EvaluateDataFrameAction.Request request,
|
||||||
ActionListener<EvaluateDataFrameAction.Response> listener) {
|
ActionListener<EvaluateDataFrameAction.Response> listener) {
|
||||||
ActionListener<List<Void>> resultsListener = ActionListener.wrap(
|
ActionListener<List<Void>> resultsListener = ActionListener.wrap(
|
||||||
unused -> {
|
unused -> {
|
||||||
|
@ -48,7 +64,9 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
||||||
listener::onFailure
|
listener::onFailure
|
||||||
);
|
);
|
||||||
|
|
||||||
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request);
|
// Create an immutable collection of parameters to be used by evaluation metrics.
|
||||||
|
EvaluationParameters parameters = new EvaluationParameters(maxBuckets.get());
|
||||||
|
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, parameters, request);
|
||||||
evaluationExecutor.execute(resultsListener);
|
evaluationExecutor.execute(resultsListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,12 +86,14 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
||||||
private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> {
|
private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> {
|
||||||
|
|
||||||
private final Client client;
|
private final Client client;
|
||||||
|
private final EvaluationParameters parameters;
|
||||||
private final EvaluateDataFrameAction.Request request;
|
private final EvaluateDataFrameAction.Request request;
|
||||||
private final Evaluation evaluation;
|
private final Evaluation evaluation;
|
||||||
|
|
||||||
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) {
|
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluationParameters parameters, EvaluateDataFrameAction.Request request) {
|
||||||
super(threadPool.generic(), unused -> true, unused -> true);
|
super(threadPool.generic(), unused -> true, unused -> true);
|
||||||
this.client = client;
|
this.client = client;
|
||||||
|
this.parameters = parameters;
|
||||||
this.request = request;
|
this.request = request;
|
||||||
this.evaluation = request.getEvaluation();
|
this.evaluation = request.getEvaluation();
|
||||||
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
|
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
|
||||||
|
@ -82,7 +102,7 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
||||||
|
|
||||||
private TypedChainTaskExecutor.ChainTask<Void> nextTask() {
|
private TypedChainTaskExecutor.ChainTask<Void> nextTask() {
|
||||||
return listener -> {
|
return listener -> {
|
||||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery());
|
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(parameters, request.getParsedQuery());
|
||||||
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
|
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
|
||||||
client.execute(
|
client.execute(
|
||||||
SearchAction.INSTANCE,
|
SearchAction.INSTANCE,
|
||||||
|
|
Loading…
Reference in New Issue