parent
a1e2e208ce
commit
3fbd58d156
|
@ -105,28 +105,31 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
return indices;
|
||||
}
|
||||
|
||||
public final void setIndices(List<String> indices) {
|
||||
public final Request setIndices(List<String> indices) {
|
||||
ExceptionsHelper.requireNonNull(indices, INDEX);
|
||||
if (indices.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("At least one index must be specified");
|
||||
}
|
||||
this.indices = indices.toArray(new String[indices.size()]);
|
||||
return this;
|
||||
}
|
||||
|
||||
public QueryBuilder getParsedQuery() {
|
||||
return Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery).getParsedQuery();
|
||||
}
|
||||
|
||||
public final void setQueryProvider(QueryProvider queryProvider) {
|
||||
public final Request setQueryProvider(QueryProvider queryProvider) {
|
||||
this.queryProvider = queryProvider;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Evaluation getEvaluation() {
|
||||
return evaluation;
|
||||
}
|
||||
|
||||
public final void setEvaluation(Evaluation evaluation) {
|
||||
public final Request setEvaluation(Evaluation evaluation) {
|
||||
this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -203,6 +206,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
this.metrics = Objects.requireNonNull(metrics);
|
||||
}
|
||||
|
||||
public String getEvaluationName() {
|
||||
return evaluationName;
|
||||
}
|
||||
|
||||
public List<EvaluationMetricResult> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(evaluationName);
|
||||
|
@ -214,7 +225,7 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
|
|||
builder.startObject();
|
||||
builder.startObject(evaluationName);
|
||||
for (EvaluationMetricResult metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
builder.field(metric.getMetricName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
builder.endObject();
|
||||
|
|
|
@ -5,14 +5,17 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Defines an evaluation
|
||||
|
@ -25,15 +28,53 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
String getName();
|
||||
|
||||
/**
|
||||
* Builds the search required to collect data to compute the evaluation result
|
||||
* @param queryBuilder User-provided query that must be respected when collecting data
|
||||
* Returns the list of metrics to evaluate
|
||||
* @return list of metrics to evaluate
|
||||
*/
|
||||
SearchSourceBuilder buildSearch(QueryBuilder queryBuilder);
|
||||
List<? extends EvaluationMetric> getMetrics();
|
||||
|
||||
/**
|
||||
* Computes the evaluation result
|
||||
* @param searchResponse The search response required to compute the result
|
||||
* @param listener A listener of the results
|
||||
* Builds the search required to collect data to compute the evaluation result
|
||||
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
|
||||
*/
|
||||
void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener);
|
||||
SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);
|
||||
|
||||
/**
|
||||
* Builds the search that verifies existence of required fields and applies user-provided query
|
||||
* @param requiredFields fields that must exist
|
||||
* @param userProvidedQueryBuilder user-provided query
|
||||
*/
|
||||
default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
|
||||
for (String requiredField : requiredFields) {
|
||||
boolQuery.filter(QueryBuilders.existsQuery(requiredField));
|
||||
}
|
||||
boolQuery.filter(userProvidedQueryBuilder);
|
||||
return new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes {@link SearchResponse} from the search action
|
||||
* @param searchResponse response from the search action
|
||||
*/
|
||||
void process(SearchResponse searchResponse);
|
||||
|
||||
/**
|
||||
* @return true iff all the metrics have their results computed
|
||||
*/
|
||||
default boolean hasAllResults() {
|
||||
return getMetrics().stream().map(EvaluationMetric::getResult).allMatch(Optional::isPresent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the list of evaluation results
|
||||
* @return list of evaluation results
|
||||
*/
|
||||
default List<EvaluationMetricResult> getResults() {
|
||||
return getMetrics().stream()
|
||||
.map(EvaluationMetric::getResult)
|
||||
.filter(Optional::isPresent)
|
||||
.map(Optional::get)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* {@link EvaluationMetric} class represents a metric to evaluate.
|
||||
*/
|
||||
public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
||||
|
||||
/**
|
||||
* Returns the name of the metric (which may differ to the writeable name)
|
||||
*/
|
||||
String getName();
|
||||
|
||||
/**
|
||||
* Gets the evaluation result for this metric.
|
||||
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
|
||||
*/
|
||||
Optional<EvaluationMetricResult> getResult();
|
||||
}
|
|
@ -14,7 +14,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
|
|||
public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable {
|
||||
|
||||
/**
|
||||
* Returns the name of the metric
|
||||
* Returns the name of the metric (which may differ to the writeable name)
|
||||
*/
|
||||
String getName();
|
||||
String getMetricName();
|
||||
}
|
||||
|
|
|
@ -20,10 +20,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
|
|||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Calculates the mean squared error between two known numerical fields.
|
||||
|
@ -48,28 +50,34 @@ public class MeanSquaredError implements RegressionMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public MeanSquaredError(StreamInput in) {
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
}
|
||||
public MeanSquaredError(StreamInput in) {}
|
||||
|
||||
public MeanSquaredError() {
|
||||
|
||||
}
|
||||
public MeanSquaredError() {}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
public void process(Aggregations aggs) {
|
||||
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
|
||||
return value == null ? new Result(0.0) : new Result(value.value());
|
||||
result = value == null ? new Result(0.0) : new Result(value.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -121,7 +129,7 @@ public class MeanSquaredError implements RegressionMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
|
|
@ -23,9 +23,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
|
|||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Calculates R-Squared between two known numerical fields.
|
||||
|
@ -53,36 +55,42 @@ public class RSquared implements RegressionMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public RSquared(StreamInput in) {
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
}
|
||||
public RSquared(StreamInput in) {}
|
||||
|
||||
public RSquared() {
|
||||
|
||||
}
|
||||
public RSquared() {}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
|
||||
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
public void process(Aggregations aggs) {
|
||||
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
|
||||
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
|
||||
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
|
||||
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
|
||||
result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
|
||||
new Result(0.0) :
|
||||
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -132,7 +140,7 @@ public class RSquared implements RegressionMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
|
@ -14,17 +13,15 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
@ -86,19 +83,16 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
|
||||
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics;
|
||||
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
|
||||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName));
|
||||
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private static List<RegressionMetric> defaultMetrics() {
|
||||
List<RegressionMetric> defaultMetrics = new ArrayList<>(2);
|
||||
defaultMetrics.add(new MeanSquaredError());
|
||||
defaultMetrics.add(new RSquared());
|
||||
return defaultMetrics;
|
||||
return Arrays.asList(new MeanSquaredError(), new RSquared());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -107,12 +101,15 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery(actualField))
|
||||
.filter(QueryBuilders.existsQuery(predictedField))
|
||||
.filter(queryBuilder);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
public List<RegressionMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
|
||||
SearchSourceBuilder searchSourceBuilder =
|
||||
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
|
||||
for (RegressionMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
|
@ -121,18 +118,14 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
|
||||
List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
|
||||
public void process(SearchResponse searchResponse) {
|
||||
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
|
||||
actualField,
|
||||
predictedField));
|
||||
return;
|
||||
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
|
||||
}
|
||||
for (RegressionMetric metric : metrics) {
|
||||
results.add(metric.evaluate(searchResponse.getAggregations()));
|
||||
metric.process(searchResponse.getAggregations());
|
||||
}
|
||||
listener.onResponse(results);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -5,20 +5,14 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface RegressionMetric extends ToXContentObject, NamedWriteable {
|
||||
|
||||
/**
|
||||
* Returns the name of the metric (which may differ to the writeable name)
|
||||
*/
|
||||
String getMetricName();
|
||||
public interface RegressionMetric extends EvaluationMetric {
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
|
@ -29,9 +23,8 @@ public interface RegressionMetric extends ToXContentObject, NamedWriteable {
|
|||
List<AggregationBuilder> aggs(String actualField, String predictedField);
|
||||
|
||||
/**
|
||||
* Calculates the metric result
|
||||
* @param aggs the aggregations
|
||||
* @return the metric result
|
||||
* Processes given aggregations as a step towards computing result
|
||||
* @param aggs aggregations from {@link SearchResponse}
|
||||
*/
|
||||
EvaluationMetricResult evaluate(Aggregations aggs);
|
||||
void process(Aggregations aggs);
|
||||
}
|
||||
|
|
|
@ -13,27 +13,31 @@ import org.elasticsearch.index.query.BoolQueryBuilder;
|
|||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
|
||||
|
||||
public static final ParseField AT = new ParseField("at");
|
||||
|
||||
protected final double[] thresholds;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
protected AbstractConfusionMatrixMetric(double[] thresholds) {
|
||||
this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
|
||||
if (thresholds.length == 0) {
|
||||
throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName()
|
||||
+ "] must have at least one value");
|
||||
throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
|
||||
}
|
||||
for (double threshold : thresholds) {
|
||||
if (threshold < 0 || threshold > 1.0) {
|
||||
throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName()
|
||||
throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName()
|
||||
+ "] values must be in [0.0, 1.0]");
|
||||
}
|
||||
}
|
||||
|
@ -58,6 +62,9 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
|||
|
||||
@Override
|
||||
public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
for (double threshold : thresholds) {
|
||||
aggs.addAll(aggsAt(actualField, classInfos, threshold));
|
||||
|
@ -65,14 +72,26 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
|||
return aggs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ClassInfo classInfo, Aggregations aggs) {
|
||||
result = evaluate(classInfo, aggs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
|
||||
|
||||
protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
|
||||
|
||||
protected enum Condition {
|
||||
TP, FP, TN, FN;
|
||||
}
|
||||
|
||||
protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
|
||||
return getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
|
||||
return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
|
||||
}
|
||||
|
||||
protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {
|
||||
|
|
|
@ -30,6 +30,7 @@ import java.util.Collections;
|
|||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
|
@ -70,6 +71,7 @@ public class AucRoc implements SoftClassificationMetric {
|
|||
}
|
||||
|
||||
private final boolean includeCurve;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
public AucRoc(Boolean includeCurve) {
|
||||
this.includeCurve = includeCurve == null ? false : includeCurve;
|
||||
|
@ -98,7 +100,7 @@ public class AucRoc implements SoftClassificationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
@ -117,6 +119,9 @@ public class AucRoc implements SoftClassificationMetric {
|
|||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
for (ClassInfo classInfo : classInfos) {
|
||||
|
@ -134,22 +139,31 @@ public class AucRoc implements SoftClassificationMetric {
|
|||
return aggs;
|
||||
}
|
||||
|
||||
private String evaluatedLabelAggName(ClassInfo classInfo) {
|
||||
return getMetricName() + "_" + classInfo.getName();
|
||||
}
|
||||
|
||||
private String restLabelsAggName(ClassInfo classInfo) {
|
||||
return getMetricName() + "_non_" + classInfo.getName();
|
||||
@Override
|
||||
public void process(ClassInfo classInfo, Aggregations aggs) {
|
||||
result = evaluate(classInfo, aggs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
private String evaluatedLabelAggName(ClassInfo classInfo) {
|
||||
return getName() + "_" + classInfo.getName();
|
||||
}
|
||||
|
||||
private String restLabelsAggName(ClassInfo classInfo) {
|
||||
return getName() + "_non_" + classInfo.getName();
|
||||
}
|
||||
|
||||
private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
|
||||
Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo));
|
||||
Filter restAgg = aggs.get(restLabelsAggName(classInfo));
|
||||
double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES),
|
||||
"[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
|
||||
"[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
|
||||
double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES),
|
||||
"[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
|
||||
"[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
|
||||
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
|
||||
double aucRocScore = calculateAucScore(aucRocCurve);
|
||||
return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
|
||||
|
@ -326,7 +340,7 @@ public class AucRoc implements SoftClassificationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
|
@ -14,18 +13,14 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
|
@ -87,17 +82,16 @@ public class BinarySoftClassification implements Evaluation {
|
|||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getMetricName));
|
||||
Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private static List<SoftClassificationMetric> defaultMetrics() {
|
||||
List<SoftClassificationMetric> defaultMetrics = new ArrayList<>(4);
|
||||
defaultMetrics.add(new AucRoc(false));
|
||||
defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75)));
|
||||
defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75)));
|
||||
defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
|
||||
return defaultMetrics;
|
||||
return Arrays.asList(
|
||||
new AucRoc(false),
|
||||
new Precision(Arrays.asList(0.25, 0.5, 0.75)),
|
||||
new Recall(Arrays.asList(0.25, 0.5, 0.75)),
|
||||
new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
|
||||
}
|
||||
|
||||
public BinarySoftClassification(StreamInput in) throws IOException {
|
||||
|
@ -126,7 +120,7 @@ public class BinarySoftClassification implements Evaluation {
|
|||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (SoftClassificationMetric metric : metrics) {
|
||||
builder.field(metric.getMetricName(), metric);
|
||||
builder.field(metric.getName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
|
||||
|
@ -155,34 +149,34 @@ public class BinarySoftClassification implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery(actualField))
|
||||
.filter(QueryBuilders.existsQuery(predictedProbabilityField))
|
||||
.filter(queryBuilder);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
public List<SoftClassificationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
|
||||
SearchSourceBuilder searchSourceBuilder =
|
||||
newSearchSourceBuilder(Arrays.asList(actualField, predictedProbabilityField), userProvidedQueryBuilder);
|
||||
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
|
||||
for (SoftClassificationMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo()));
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo));
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
|
||||
public void process(SearchResponse searchResponse) {
|
||||
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField,
|
||||
predictedProbabilityField));
|
||||
return;
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField);
|
||||
}
|
||||
|
||||
List<EvaluationMetricResult> results = new ArrayList<>();
|
||||
Aggregations aggs = searchResponse.getAggregations();
|
||||
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
|
||||
for (SoftClassificationMetric metric : metrics) {
|
||||
results.add(metric.evaluate(binaryClassInfo, aggs));
|
||||
metric.process(binaryClassInfo, searchResponse.getAggregations());
|
||||
}
|
||||
listener.onResponse(results);
|
||||
}
|
||||
|
||||
private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {
|
||||
|
|
|
@ -50,7 +50,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
@ -132,7 +132,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ public class ScoreByThresholdResult implements EvaluationMetricResult {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
public String getMetricName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
|
|
|
@ -5,16 +5,15 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface SoftClassificationMetric extends ToXContentObject, NamedWriteable {
|
||||
public interface SoftClassificationMetric extends EvaluationMetric {
|
||||
|
||||
/**
|
||||
* The information of a specific class
|
||||
|
@ -37,11 +36,6 @@ public interface SoftClassificationMetric extends ToXContentObject, NamedWriteab
|
|||
String getProbabilityField();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the name of the metric (which may differ to the writeable name)
|
||||
*/
|
||||
String getMetricName();
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param actualField the field that stores the actual class
|
||||
|
@ -51,10 +45,9 @@ public interface SoftClassificationMetric extends ToXContentObject, NamedWriteab
|
|||
List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos);
|
||||
|
||||
/**
|
||||
* Calculates the metric result for a given class
|
||||
* Processes given aggregations as a step towards computing result
|
||||
* @param classInfo the class to calculate the metric for
|
||||
* @param aggs the aggregations
|
||||
* @return the metric result
|
||||
* @param aggs aggregations from {@link SearchResponse}
|
||||
*/
|
||||
EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
|
||||
void process(ClassInfo classInfo, Aggregations aggs);
|
||||
}
|
||||
|
|
|
@ -49,8 +49,9 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
|||
));
|
||||
|
||||
MeanSquaredError mse = new MeanSquaredError();
|
||||
EvaluationMetricResult result = mse.evaluate(aggs);
|
||||
mse.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = mse.getResult().get();
|
||||
String expected = "{\"error\":0.8123}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
@ -61,7 +62,9 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
|||
));
|
||||
|
||||
MeanSquaredError mse = new MeanSquaredError();
|
||||
EvaluationMetricResult result = mse.evaluate(aggs);
|
||||
mse.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = mse.getResult().get();
|
||||
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
|
||||
}
|
||||
|
||||
|
|
|
@ -52,8 +52,9 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
|||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
rSquared.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = rSquared.getResult().get();
|
||||
String expected = "{\"value\":0.9348643947690524}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
@ -67,35 +68,48 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
|||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
rSquared.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = rSquared.getResult().get();
|
||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingAggs() {
|
||||
EvaluationMetricResult zeroResult = new RSquared.Result(0.0);
|
||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
assertThat(result, equalTo(zeroResult));
|
||||
rSquared.process(aggs);
|
||||
|
||||
aggs = new Aggregations(Arrays.asList(
|
||||
EvaluationMetricResult result = rSquared.getResult().get();
|
||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingExtendedStatsAgg() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
||||
createSingleMetricAgg("residual_sum_of_squares", 0.2377)
|
||||
));
|
||||
|
||||
result = rSquared.evaluate(aggs);
|
||||
assertThat(result, equalTo(zeroResult));
|
||||
RSquared rSquared = new RSquared();
|
||||
rSquared.process(aggs);
|
||||
|
||||
aggs = new Aggregations(Arrays.asList(
|
||||
EvaluationMetricResult result = rSquared.getResult().get();
|
||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
||||
createExtendedStatsAgg("extended_stats_actual",100, 50)
|
||||
));
|
||||
|
||||
result = rSquared.evaluate(aggs);
|
||||
assertThat(result, equalTo(zeroResult));
|
||||
RSquared rSquared = new RSquared();
|
||||
rSquared.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = rSquared.getResult().get();
|
||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||
}
|
||||
|
||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
|
@ -22,6 +23,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
|
||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
||||
|
@ -43,13 +45,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
if (randomBoolean()) {
|
||||
metrics.add(RSquaredTests.createRandom());
|
||||
}
|
||||
return new Regression(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
metrics.isEmpty() ?
|
||||
null :
|
||||
metrics);
|
||||
return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -74,7 +70,6 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
}
|
||||
|
||||
public void testBuildSearch() {
|
||||
Regression evaluation = new Regression("act", "prob", Arrays.asList(new MeanSquaredError()));
|
||||
QueryBuilder userProvidedQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
|
@ -82,10 +77,15 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
QueryBuilder expectedSearchQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery("act"))
|
||||
.filter(QueryBuilders.existsQuery("prob"))
|
||||
.filter(QueryBuilders.existsQuery("pred"))
|
||||
.filter(QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
|
||||
assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
|
||||
|
||||
Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError()));
|
||||
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
|
@ -22,6 +23,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
|
||||
public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> {
|
||||
|
||||
|
@ -81,7 +83,6 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
|
|||
}
|
||||
|
||||
public void testBuildSearch() {
|
||||
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
|
||||
QueryBuilder userProvidedQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
|
@ -93,6 +94,11 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
|
|||
.filter(QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
|
||||
assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
|
||||
|
||||
BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
|
||||
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,12 +12,13 @@ import org.elasticsearch.action.support.ActionFilters;
|
|||
import org.elasticsearch.action.support.HandledTransportAction;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -38,24 +39,64 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
|
|||
@Override
|
||||
protected void doExecute(Task task, EvaluateDataFrameAction.Request request,
|
||||
ActionListener<EvaluateDataFrameAction.Response> listener) {
|
||||
Evaluation evaluation = request.getEvaluation();
|
||||
SearchRequest searchRequest = new SearchRequest(request.getIndices());
|
||||
searchRequest.source(evaluation.buildSearch(request.getParsedQuery()));
|
||||
|
||||
ActionListener<List<EvaluationMetricResult>> resultsListener = ActionListener.wrap(
|
||||
results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)),
|
||||
ActionListener<List<Void>> resultsListener = ActionListener.wrap(
|
||||
unused -> {
|
||||
EvaluateDataFrameAction.Response response =
|
||||
new EvaluateDataFrameAction.Response(request.getEvaluation().getName(), request.getEvaluation().getResults());
|
||||
listener.onResponse(response);
|
||||
},
|
||||
listener::onFailure
|
||||
);
|
||||
|
||||
client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
|
||||
searchResponse -> threadPool.generic().execute(() -> {
|
||||
try {
|
||||
evaluation.evaluate(searchResponse, resultsListener);
|
||||
} catch (Exception e) {
|
||||
listener.onFailure(e);
|
||||
EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request);
|
||||
evaluationExecutor.execute(resultsListener);
|
||||
}
|
||||
|
||||
/**
|
||||
* {@link EvaluationExecutor} class allows for serial execution of evaluation steps.
|
||||
*
|
||||
* Each step consists of the following phases:
|
||||
* 1. build search request with aggs requested by individual metrics
|
||||
* 2. execute search action with the request built in (1.)
|
||||
* 3. make all individual metrics process the search response obtained in (2.)
|
||||
* 4. check if all the metrics have their results computed
|
||||
* a) If so, call the final listener and finish
|
||||
* b) Otherwise, add another step to the queue
|
||||
*
|
||||
* To avoid infinite loop it is essential that every metric *does* compute its result at some point.
|
||||
* */
|
||||
private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> {
|
||||
|
||||
private final Client client;
|
||||
private final EvaluateDataFrameAction.Request request;
|
||||
private final Evaluation evaluation;
|
||||
|
||||
EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) {
|
||||
super(threadPool.generic(), unused -> true, unused -> true);
|
||||
this.client = client;
|
||||
this.request = request;
|
||||
this.evaluation = request.getEvaluation();
|
||||
// Add one task only. Other tasks will be added as needed by the nextTask method itself.
|
||||
add(nextTask());
|
||||
}
|
||||
|
||||
private TypedChainTaskExecutor.ChainTask<Void> nextTask() {
|
||||
return listener -> {
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery());
|
||||
SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
|
||||
client.execute(
|
||||
SearchAction.INSTANCE,
|
||||
searchRequest,
|
||||
ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
evaluation.process(searchResponse);
|
||||
if (evaluation.hasAllResults() == false) {
|
||||
add(nextTask());
|
||||
}
|
||||
listener.onResponse(null);
|
||||
},
|
||||
listener::onFailure));
|
||||
};
|
||||
}),
|
||||
listener::onFailure
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue