[7.x] Allow evaluation to consist of multiple steps. (#46653) (#47194)

This commit is contained in:
Przemysław Witek 2019-09-27 13:01:51 +02:00 committed by GitHub
parent a1e2e208ce
commit 3fbd58d156
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 342 additions and 176 deletions

View File

@ -105,28 +105,31 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
return indices;
}
public final void setIndices(List<String> indices) {
public final Request setIndices(List<String> indices) {
ExceptionsHelper.requireNonNull(indices, INDEX);
if (indices.isEmpty()) {
throw ExceptionsHelper.badRequestException("At least one index must be specified");
}
this.indices = indices.toArray(new String[indices.size()]);
return this;
}
public QueryBuilder getParsedQuery() {
return Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery).getParsedQuery();
}
public final void setQueryProvider(QueryProvider queryProvider) {
public final Request setQueryProvider(QueryProvider queryProvider) {
this.queryProvider = queryProvider;
return this;
}
public Evaluation getEvaluation() {
return evaluation;
}
public final void setEvaluation(Evaluation evaluation) {
public final Request setEvaluation(Evaluation evaluation) {
this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION);
return this;
}
@Override
@ -203,6 +206,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
this.metrics = Objects.requireNonNull(metrics);
}
public String getEvaluationName() {
return evaluationName;
}
public List<EvaluationMetricResult> getMetrics() {
return metrics;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(evaluationName);
@ -214,7 +225,7 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
builder.startObject();
builder.startObject(evaluationName);
for (EvaluationMetricResult metric : metrics) {
builder.field(metric.getName(), metric);
builder.field(metric.getMetricName(), metric);
}
builder.endObject();
builder.endObject();

View File

@ -5,14 +5,17 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* Defines an evaluation
@ -25,15 +28,53 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
String getName();
/**
* Builds the search required to collect data to compute the evaluation result
* @param queryBuilder User-provided query that must be respected when collecting data
* Returns the list of metrics to evaluate
* @return list of metrics to evaluate
*/
SearchSourceBuilder buildSearch(QueryBuilder queryBuilder);
List<? extends EvaluationMetric> getMetrics();
/**
* Computes the evaluation result
* @param searchResponse The search response required to compute the result
* @param listener A listener of the results
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener);
SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);
/**
* Builds the search that verifies existence of required fields and applies user-provided query
* @param requiredFields fields that must exist
* @param userProvidedQueryBuilder user-provided query
*/
default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
for (String requiredField : requiredFields) {
boolQuery.filter(QueryBuilders.existsQuery(requiredField));
}
boolQuery.filter(userProvidedQueryBuilder);
return new SearchSourceBuilder().size(0).query(boolQuery);
}
/**
* Processes {@link SearchResponse} from the search action
* @param searchResponse response from the search action
*/
void process(SearchResponse searchResponse);
/**
* @return true iff all the metrics have their results computed
*/
default boolean hasAllResults() {
return getMetrics().stream().map(EvaluationMetric::getResult).allMatch(Optional::isPresent);
}
/**
* Returns the list of evaluation results
* @return list of evaluation results
*/
default List<EvaluationMetricResult> getResults() {
return getMetrics().stream()
.map(EvaluationMetric::getResult)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import java.util.Optional;
/**
* {@link EvaluationMetric} class represents a metric to evaluate.
*/
public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
/**
* Returns the name of the metric (which may differ to the writeable name)
*/
String getName();
/**
* Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
*/
Optional<EvaluationMetricResult> getResult();
}

View File

@ -14,7 +14,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable {
/**
* Returns the name of the metric
* Returns the name of the metric (which may differ to the writeable name)
*/
String getName();
String getMetricName();
}

View File

@ -20,10 +20,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
/**
* Calculates the mean squared error between two known numerical fields.
@ -48,28 +50,34 @@ public class MeanSquaredError implements RegressionMetric {
return PARSER.apply(parser, null);
}
public MeanSquaredError(StreamInput in) {
private EvaluationMetricResult result;
}
public MeanSquaredError(StreamInput in) {}
public MeanSquaredError() {
}
public MeanSquaredError() {}
@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}
@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
if (result != null) {
return Collections.emptyList();
}
return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
}
@Override
public EvaluationMetricResult evaluate(Aggregations aggs) {
public void process(Aggregations aggs) {
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
return value == null ? new Result(0.0) : new Result(value.value());
result = value == null ? new Result(0.0) : new Result(value.value());
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
@ -121,7 +129,7 @@ public class MeanSquaredError implements RegressionMetric {
}
@Override
public String getName() {
public String getMetricName() {
return NAME.getPreferredName();
}

View File

@ -23,9 +23,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
/**
* Calculates R-Squared between two known numerical fields.
@ -53,36 +55,42 @@ public class RSquared implements RegressionMetric {
return PARSER.apply(parser, null);
}
public RSquared(StreamInput in) {
private EvaluationMetricResult result;
}
public RSquared(StreamInput in) {}
public RSquared() {
}
public RSquared() {}
@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}
@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
if (result != null) {
return Collections.emptyList();
}
return Arrays.asList(
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
}
@Override
public EvaluationMetricResult evaluate(Aggregations aggs) {
public void process(Aggregations aggs) {
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
new Result(0.0) :
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -132,7 +140,7 @@ public class RSquared implements RegressionMetric {
}
@Override
public String getName() {
public String getMetricName() {
return NAME.getPreferredName();
}

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
@ -14,17 +13,15 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
@ -86,19 +83,16 @@ public class Regression implements Evaluation {
}
private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics;
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName));
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName));
return metrics;
}
private static List<RegressionMetric> defaultMetrics() {
List<RegressionMetric> defaultMetrics = new ArrayList<>(2);
defaultMetrics.add(new MeanSquaredError());
defaultMetrics.add(new RSquared());
return defaultMetrics;
return Arrays.asList(new MeanSquaredError(), new RSquared());
}
@Override
@ -107,12 +101,15 @@ public class Regression implements Evaluation {
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery(actualField))
.filter(QueryBuilders.existsQuery(predictedField))
.filter(queryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
public List<RegressionMetric> getMetrics() {
return metrics;
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder =
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
for (RegressionMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
aggs.forEach(searchSourceBuilder::aggregation);
@ -121,18 +118,14 @@ public class Regression implements Evaluation {
}
@Override
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
actualField,
predictedField));
return;
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
}
for (RegressionMetric metric : metrics) {
results.add(metric.evaluate(searchResponse.getAggregations()));
metric.process(searchResponse.getAggregations());
}
listener.onResponse(results);
}
@Override

View File

@ -5,20 +5,14 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import java.util.List;
public interface RegressionMetric extends ToXContentObject, NamedWriteable {
/**
* Returns the name of the metric (which may differ to the writeable name)
*/
String getMetricName();
public interface RegressionMetric extends EvaluationMetric {
/**
* Builds the aggregation that collect required data to compute the metric
@ -29,9 +23,8 @@ public interface RegressionMetric extends ToXContentObject, NamedWriteable {
List<AggregationBuilder> aggs(String actualField, String predictedField);
/**
* Calculates the metric result
* @param aggs the aggregations
* @return the metric result
* Processes given aggregations as a step towards computing result
* @param aggs aggregations from {@link SearchResponse}
*/
EvaluationMetricResult evaluate(Aggregations aggs);
void process(Aggregations aggs);
}

View File

@ -13,27 +13,31 @@ import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
public static final ParseField AT = new ParseField("at");
protected final double[] thresholds;
private EvaluationMetricResult result;
protected AbstractConfusionMatrixMetric(double[] thresholds) {
this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
if (thresholds.length == 0) {
throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName()
+ "] must have at least one value");
throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
}
for (double threshold : thresholds) {
if (threshold < 0 || threshold > 1.0) {
throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName()
throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName()
+ "] values must be in [0.0, 1.0]");
}
}
@ -58,6 +62,9 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
@Override
public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
if (result != null) {
return Collections.emptyList();
}
List<AggregationBuilder> aggs = new ArrayList<>();
for (double threshold : thresholds) {
aggs.addAll(aggsAt(actualField, classInfos, threshold));
@ -65,14 +72,26 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
return aggs;
}
@Override
public void process(ClassInfo classInfo, Aggregations aggs) {
result = evaluate(classInfo, aggs);
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
protected enum Condition {
TP, FP, TN, FN;
}
protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
return getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
}
protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {

View File

@ -30,6 +30,7 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
/**
@ -70,6 +71,7 @@ public class AucRoc implements SoftClassificationMetric {
}
private final boolean includeCurve;
private EvaluationMetricResult result;
public AucRoc(Boolean includeCurve) {
this.includeCurve = includeCurve == null ? false : includeCurve;
@ -98,7 +100,7 @@ public class AucRoc implements SoftClassificationMetric {
}
@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}
@ -117,6 +119,9 @@ public class AucRoc implements SoftClassificationMetric {
@Override
public List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
if (result != null) {
return Collections.emptyList();
}
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
List<AggregationBuilder> aggs = new ArrayList<>();
for (ClassInfo classInfo : classInfos) {
@ -134,22 +139,31 @@ public class AucRoc implements SoftClassificationMetric {
return aggs;
}
private String evaluatedLabelAggName(ClassInfo classInfo) {
return getMetricName() + "_" + classInfo.getName();
}
private String restLabelsAggName(ClassInfo classInfo) {
return getMetricName() + "_non_" + classInfo.getName();
@Override
public void process(ClassInfo classInfo, Aggregations aggs) {
result = evaluate(classInfo, aggs);
}
@Override
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
private String evaluatedLabelAggName(ClassInfo classInfo) {
return getName() + "_" + classInfo.getName();
}
private String restLabelsAggName(ClassInfo classInfo) {
return getName() + "_non_" + classInfo.getName();
}
private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo));
Filter restAgg = aggs.get(restLabelsAggName(classInfo));
double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES),
"[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
"[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES),
"[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
"[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
@ -326,7 +340,7 @@ public class AucRoc implements SoftClassificationMetric {
}
@Override
public String getName() {
public String getMetricName() {
return NAME.getPreferredName();
}

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
@ -14,18 +13,14 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
@ -87,17 +82,16 @@ public class BinarySoftClassification implements Evaluation {
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getMetricName));
Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName));
return metrics;
}
private static List<SoftClassificationMetric> defaultMetrics() {
List<SoftClassificationMetric> defaultMetrics = new ArrayList<>(4);
defaultMetrics.add(new AucRoc(false));
defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75)));
defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75)));
defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
return defaultMetrics;
return Arrays.asList(
new AucRoc(false),
new Precision(Arrays.asList(0.25, 0.5, 0.75)),
new Recall(Arrays.asList(0.25, 0.5, 0.75)),
new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
}
public BinarySoftClassification(StreamInput in) throws IOException {
@ -126,7 +120,7 @@ public class BinarySoftClassification implements Evaluation {
builder.startObject(METRICS.getPreferredName());
for (SoftClassificationMetric metric : metrics) {
builder.field(metric.getMetricName(), metric);
builder.field(metric.getName(), metric);
}
builder.endObject();
@ -155,34 +149,34 @@ public class BinarySoftClassification implements Evaluation {
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery(actualField))
.filter(QueryBuilders.existsQuery(predictedProbabilityField))
.filter(queryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
public List<SoftClassificationMetric> getMetrics() {
return metrics;
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder =
newSearchSourceBuilder(Arrays.asList(actualField, predictedProbabilityField), userProvidedQueryBuilder);
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
for (SoftClassificationMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo()));
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo));
aggs.forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
}
@Override
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField,
predictedProbabilityField));
return;
throw ExceptionsHelper.badRequestException(
"No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField);
}
List<EvaluationMetricResult> results = new ArrayList<>();
Aggregations aggs = searchResponse.getAggregations();
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
for (SoftClassificationMetric metric : metrics) {
results.add(metric.evaluate(binaryClassInfo, aggs));
metric.process(binaryClassInfo, searchResponse.getAggregations());
}
listener.onResponse(results);
}
private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {

View File

@ -50,7 +50,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
}
@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}
@ -132,7 +132,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
}
@Override
public String getName() {
public String getMetricName() {
return NAME.getPreferredName();
}

View File

@ -48,7 +48,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
}
@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}

View File

@ -48,7 +48,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
}
@Override
public String getMetricName() {
public String getName() {
return NAME.getPreferredName();
}
@ -68,7 +68,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
@Override
protected List<AggregationBuilder> aggsAt(String actualField, List<ClassInfo> classInfos, double threshold) {
List<AggregationBuilder> aggs = new ArrayList<>();
for (ClassInfo classInfo: classInfos) {
for (ClassInfo classInfo : classInfos) {
aggs.add(buildAgg(classInfo, threshold, Condition.TP));
aggs.add(buildAgg(classInfo, threshold, Condition.FN));
}

View File

@ -40,7 +40,7 @@ public class ScoreByThresholdResult implements EvaluationMetricResult {
}
@Override
public String getName() {
public String getMetricName() {
return name;
}

View File

@ -5,16 +5,15 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import java.util.List;
public interface SoftClassificationMetric extends ToXContentObject, NamedWriteable {
public interface SoftClassificationMetric extends EvaluationMetric {
/**
* The information of a specific class
@ -37,11 +36,6 @@ public interface SoftClassificationMetric extends ToXContentObject, NamedWriteab
String getProbabilityField();
}
/**
* Returns the name of the metric (which may differ to the writeable name)
*/
String getMetricName();
/**
* Builds the aggregation that collect required data to compute the metric
* @param actualField the field that stores the actual class
@ -51,10 +45,9 @@ public interface SoftClassificationMetric extends ToXContentObject, NamedWriteab
List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos);
/**
* Calculates the metric result for a given class
* Processes given aggregations as a step towards computing result
* @param classInfo the class to calculate the metric for
* @param aggs the aggregations
* @return the metric result
* @param aggs aggregations from {@link SearchResponse}
*/
EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
void process(ClassInfo classInfo, Aggregations aggs);
}

View File

@ -49,8 +49,9 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
));
MeanSquaredError mse = new MeanSquaredError();
EvaluationMetricResult result = mse.evaluate(aggs);
mse.process(aggs);
EvaluationMetricResult result = mse.getResult().get();
String expected = "{\"error\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected));
}
@ -61,7 +62,9 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
));
MeanSquaredError mse = new MeanSquaredError();
EvaluationMetricResult result = mse.evaluate(aggs);
mse.process(aggs);
EvaluationMetricResult result = mse.getResult().get();
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
}

View File

@ -52,8 +52,9 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
));
RSquared rSquared = new RSquared();
EvaluationMetricResult result = rSquared.evaluate(aggs);
rSquared.process(aggs);
EvaluationMetricResult result = rSquared.getResult().get();
String expected = "{\"value\":0.9348643947690524}";
assertThat(Strings.toString(result), equalTo(expected));
}
@ -67,35 +68,48 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
));
RSquared rSquared = new RSquared();
EvaluationMetricResult result = rSquared.evaluate(aggs);
rSquared.process(aggs);
EvaluationMetricResult result = rSquared.getResult().get();
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
public void testEvaluate_GivenMissingAggs() {
EvaluationMetricResult zeroResult = new RSquared.Result(0.0);
Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
RSquared rSquared = new RSquared();
EvaluationMetricResult result = rSquared.evaluate(aggs);
assertThat(result, equalTo(zeroResult));
rSquared.process(aggs);
aggs = new Aggregations(Arrays.asList(
EvaluationMetricResult result = rSquared.getResult().get();
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
public void testEvaluate_GivenMissingExtendedStatsAgg() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
createSingleMetricAgg("residual_sum_of_squares", 0.2377)
));
result = rSquared.evaluate(aggs);
assertThat(result, equalTo(zeroResult));
RSquared rSquared = new RSquared();
rSquared.process(aggs);
aggs = new Aggregations(Arrays.asList(
EvaluationMetricResult result = rSquared.getResult().get();
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
createExtendedStatsAgg("extended_stats_actual",100, 50)
));
result = rSquared.evaluate(aggs);
assertThat(result, equalTo(zeroResult));
RSquared rSquared = new RSquared();
rSquared.process(aggs);
EvaluationMetricResult result = rSquared.getResult().get();
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {

View File

@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
@ -22,6 +23,7 @@ import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
@ -43,13 +45,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
if (randomBoolean()) {
metrics.add(RSquaredTests.createRandom());
}
return new Regression(randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomBoolean() ?
null :
metrics.isEmpty() ?
null :
metrics);
return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}
@Override
@ -74,7 +70,6 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
}
public void testBuildSearch() {
Regression evaluation = new Regression("act", "prob", Arrays.asList(new MeanSquaredError()));
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
@ -82,10 +77,15 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
QueryBuilder expectedSearchQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery("act"))
.filter(QueryBuilders.existsQuery("prob"))
.filter(QueryBuilders.existsQuery("pred"))
.filter(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError()));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
@ -22,6 +23,7 @@ import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> {
@ -81,7 +83,6 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
}
public void testBuildSearch() {
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
@ -93,6 +94,11 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
.filter(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
}
}

View File

@ -12,12 +12,13 @@ import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
import java.util.List;
@ -38,24 +39,64 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
@Override
protected void doExecute(Task task, EvaluateDataFrameAction.Request request,
ActionListener<EvaluateDataFrameAction.Response> listener) {
Evaluation evaluation = request.getEvaluation();
SearchRequest searchRequest = new SearchRequest(request.getIndices());
searchRequest.source(evaluation.buildSearch(request.getParsedQuery()));
ActionListener<List<EvaluationMetricResult>> resultsListener = ActionListener.wrap(
results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)),
ActionListener<List<Void>> resultsListener = ActionListener.wrap(
unused -> {
EvaluateDataFrameAction.Response response =
new EvaluateDataFrameAction.Response(request.getEvaluation().getName(), request.getEvaluation().getResults());
listener.onResponse(response);
},
listener::onFailure
);
client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
searchResponse -> threadPool.generic().execute(() -> {
try {
evaluation.evaluate(searchResponse, resultsListener);
} catch (Exception e) {
listener.onFailure(e);
};
}),
listener::onFailure
));
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request);
evaluationExecutor.execute(resultsListener);
}
/**
* {@link EvaluationExecutor} class allows for serial execution of evaluation steps.
*
* Each step consists of the following phases:
* 1. build search request with aggs requested by individual metrics
* 2. execute search action with the request built in (1.)
* 3. make all individual metrics process the search response obtained in (2.)
* 4. check if all the metrics have their results computed
* a) If so, call the final listener and finish
* b) Otherwise, add another step to the queue
*
* To avoid infinite loop it is essential that every metric *does* compute its result at some point.
* */
private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> {
private final Client client;
private final EvaluateDataFrameAction.Request request;
private final Evaluation evaluation;
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) {
super(threadPool.generic(), unused -> true, unused -> true);
this.client = client;
this.request = request;
this.evaluation = request.getEvaluation();
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
add(nextTask());
}
private TypedChainTaskExecutor.ChainTask<Void> nextTask() {
return listener -> {
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery());
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
client.execute(
SearchAction.INSTANCE,
searchRequest,
ActionListener.wrap(
searchResponse -> {
evaluation.process(searchResponse);
if (evaluation.hasAllResults() == false) {
add(nextTask());
}
listener.onResponse(null);
},
listener::onFailure));
};
}
}
}