mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-22 12:56:53 +00:00
This commit is contained in:
parent
496bb9e2ee
commit
1425e30b1e
@ -6,15 +6,23 @@
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
@ -27,37 +35,67 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
||||
*/
|
||||
String getName();
|
||||
|
||||
/**
|
||||
* Returns the field containing the actual value
|
||||
*/
|
||||
String getActualField();
|
||||
|
||||
/**
|
||||
* Returns the field containing the predicted value
|
||||
*/
|
||||
String getPredictedField();
|
||||
|
||||
/**
|
||||
* Returns the list of metrics to evaluate
|
||||
* @return list of metrics to evaluate
|
||||
*/
|
||||
List<? extends EvaluationMetric> getMetrics();
|
||||
|
||||
default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parsedMetrics, Supplier<List<T>> defaultMetricsSupplier) {
|
||||
List<T> metrics = parsedMetrics == null ? defaultMetricsSupplier.get() : new ArrayList<>(parsedMetrics);
|
||||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the search required to collect data to compute the evaluation result
|
||||
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
|
||||
*/
|
||||
SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);
|
||||
|
||||
/**
|
||||
* Builds the search that verifies existence of required fields and applies user-provided query
|
||||
* @param requiredFields fields that must exist
|
||||
* @param userProvidedQueryBuilder user-provided query
|
||||
*/
|
||||
default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
|
||||
for (String requiredField : requiredFields) {
|
||||
boolQuery.filter(QueryBuilders.existsQuery(requiredField));
|
||||
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
Objects.requireNonNull(userProvidedQueryBuilder);
|
||||
BoolQueryBuilder boolQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
// Verify existence of required fields
|
||||
.filter(QueryBuilders.existsQuery(getActualField()))
|
||||
.filter(QueryBuilders.existsQuery(getPredictedField()))
|
||||
// Apply user-provided query
|
||||
.filter(userProvidedQueryBuilder);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
for (EvaluationMetric metric : getMetrics()) {
|
||||
// Fetch aggregations requested by individual metrics
|
||||
List<AggregationBuilder> aggs = metric.aggs(getActualField(), getPredictedField());
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
boolQuery.filter(userProvidedQueryBuilder);
|
||||
return new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes {@link SearchResponse} from the search action
|
||||
* @param searchResponse response from the search action
|
||||
*/
|
||||
void process(SearchResponse searchResponse);
|
||||
default void process(SearchResponse searchResponse) {
|
||||
Objects.requireNonNull(searchResponse);
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
|
||||
}
|
||||
for (EvaluationMetric metric : getMetrics()) {
|
||||
metric.process(searchResponse.getAggregations());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return true iff all the metrics have their results computed
|
||||
|
@ -5,9 +5,13 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
@ -20,6 +24,20 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
||||
*/
|
||||
String getName();
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param actualField the field that stores the actual value
|
||||
* @param predictedField the field that stores the predicted value (class name or probability)
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
List<AggregationBuilder> aggs(String actualField, String predictedField);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
* @param aggs aggregations from {@link SearchResponse}
|
||||
*/
|
||||
void process(Aggregations aggs);
|
||||
|
||||
/**
|
||||
* Gets the evaluation result for this metric.
|
||||
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
|
||||
|
@ -5,7 +5,6 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@ -55,13 +48,13 @@ public class Classification implements Evaluation {
|
||||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
* The value of this field is assumed to be numeric
|
||||
* The value of this field is assumed to be categorical
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
* The value of this field is assumed to be numeric
|
||||
* The value of this field is assumed to be categorical
|
||||
*/
|
||||
private final String predictedField;
|
||||
|
||||
@ -73,7 +66,11 @@ public class Classification implements Evaluation {
|
||||
public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
|
||||
this.metrics = initMetrics(metrics);
|
||||
this.metrics = initMetrics(metrics, Classification::defaultMetrics);
|
||||
}
|
||||
|
||||
private static List<ClassificationMetric> defaultMetrics() {
|
||||
return Arrays.asList(new MulticlassConfusionMatrix());
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
@ -82,52 +79,26 @@ public class Classification implements Evaluation {
|
||||
this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
|
||||
}
|
||||
|
||||
private static List<ClassificationMetric> initMetrics(@Nullable List<ClassificationMetric> parsedMetrics) {
|
||||
List<ClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
|
||||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private static List<ClassificationMetric> defaultMetrics() {
|
||||
return Arrays.asList(new MulticlassConfusionMatrix());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getActualField() {
|
||||
return actualField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getPredictedField() {
|
||||
return predictedField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ClassificationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
|
||||
SearchSourceBuilder searchSourceBuilder =
|
||||
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
|
||||
for (ClassificationMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(SearchResponse searchResponse) {
|
||||
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
|
||||
}
|
||||
for (ClassificationMetric metric : metrics) {
|
||||
metric.process(searchResponse.getAggregations());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -5,26 +5,7 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ClassificationMetric extends EvaluationMetric {
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param actualField the field that stores the actual value
|
||||
* @param predictedField the field that stores the predicted value
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
List<AggregationBuilder> aggs(String actualField, String predictedField);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
* @param aggs aggregations from {@link SearchResponse}
|
||||
*/
|
||||
void process(Aggregations aggs);
|
||||
}
|
||||
|
@ -5,7 +5,6 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@ -73,7 +66,11 @@ public class Regression implements Evaluation {
|
||||
public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
|
||||
this.metrics = initMetrics(metrics);
|
||||
this.metrics = initMetrics(metrics, Regression::defaultMetrics);
|
||||
}
|
||||
|
||||
private static List<RegressionMetric> defaultMetrics() {
|
||||
return Arrays.asList(new MeanSquaredError(), new RSquared());
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
@ -82,52 +79,26 @@ public class Regression implements Evaluation {
|
||||
this.metrics = in.readNamedWriteableList(RegressionMetric.class);
|
||||
}
|
||||
|
||||
private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
|
||||
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
|
||||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private static List<RegressionMetric> defaultMetrics() {
|
||||
return Arrays.asList(new MeanSquaredError(), new RSquared());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getActualField() {
|
||||
return actualField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getPredictedField() {
|
||||
return predictedField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RegressionMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
|
||||
SearchSourceBuilder searchSourceBuilder =
|
||||
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
|
||||
for (RegressionMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(SearchResponse searchResponse) {
|
||||
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
|
||||
}
|
||||
for (RegressionMetric metric : metrics) {
|
||||
metric.process(searchResponse.getAggregations());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
|
@ -5,26 +5,7 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface RegressionMetric extends EvaluationMetric {
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param actualField the field that stores the actual value
|
||||
* @param predictedField the field that stores the predicted value
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
List<AggregationBuilder> aggs(String actualField, String predictedField);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
* @param aggs aggregations from {@link SearchResponse}
|
||||
*/
|
||||
void process(Aggregations aggs);
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
@ -18,11 +19,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
|
||||
|
||||
abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
|
||||
|
||||
public static final ParseField AT = new ParseField("at");
|
||||
@ -30,8 +32,8 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
||||
protected final double[] thresholds;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
protected AbstractConfusionMatrixMetric(double[] thresholds) {
|
||||
this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
|
||||
protected AbstractConfusionMatrixMetric(List<Double> at) {
|
||||
this.thresholds = ExceptionsHelper.requireNonNull(at, AT).stream().mapToDouble(Double::doubleValue).toArray();
|
||||
if (thresholds.length == 0) {
|
||||
throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
|
||||
}
|
||||
@ -61,20 +63,16 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
||||
}
|
||||
|
||||
@Override
|
||||
public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
|
||||
public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
for (double threshold : thresholds) {
|
||||
aggs.addAll(aggsAt(actualField, classInfos, threshold));
|
||||
}
|
||||
return aggs;
|
||||
return aggsAt(actualField, predictedProbabilityField);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ClassInfo classInfo, Aggregations aggs) {
|
||||
result = evaluate(classInfo, aggs);
|
||||
public void process(Aggregations aggs) {
|
||||
result = evaluate(aggs);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -82,40 +80,43 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
|
||||
protected abstract List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField);
|
||||
|
||||
protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
|
||||
protected abstract EvaluationMetricResult evaluate(Aggregations aggs);
|
||||
|
||||
protected enum Condition {
|
||||
TP, FP, TN, FN;
|
||||
}
|
||||
enum Condition {
|
||||
TP(true, true),
|
||||
FP(false, true),
|
||||
TN(false, false),
|
||||
FN(true, false);
|
||||
|
||||
protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
|
||||
return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
|
||||
}
|
||||
final boolean actual;
|
||||
final boolean predicted;
|
||||
|
||||
protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
|
||||
switch (condition) {
|
||||
case TP:
|
||||
boolQuery.must(classInfo.matchingQuery());
|
||||
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
|
||||
break;
|
||||
case FP:
|
||||
boolQuery.mustNot(classInfo.matchingQuery());
|
||||
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
|
||||
break;
|
||||
case TN:
|
||||
boolQuery.mustNot(classInfo.matchingQuery());
|
||||
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
|
||||
break;
|
||||
case FN:
|
||||
boolQuery.must(classInfo.matchingQuery());
|
||||
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException("Unknown enum value: " + condition);
|
||||
Condition(boolean actual, boolean predicted) {
|
||||
this.actual = actual;
|
||||
this.predicted = predicted;
|
||||
}
|
||||
return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery);
|
||||
}
|
||||
|
||||
protected String aggName(double threshold, Condition condition) {
|
||||
return getName() + "_at_" + threshold + "_" + condition.name();
|
||||
}
|
||||
|
||||
protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
|
||||
QueryBuilder actualIsTrueQuery = actualIsTrueQuery(actualField);
|
||||
QueryBuilder predictedIsTrueQuery = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold);
|
||||
if (condition.actual) {
|
||||
boolQuery.must(actualIsTrueQuery);
|
||||
} else {
|
||||
boolQuery.mustNot(actualIsTrueQuery);
|
||||
}
|
||||
if (condition.predicted) {
|
||||
boolQuery.must(predictedIsTrueQuery);
|
||||
} else {
|
||||
boolQuery.mustNot(predictedIsTrueQuery);
|
||||
}
|
||||
return AggregationBuilders.filter(aggName(threshold, condition), boolQuery);
|
||||
}
|
||||
}
|
||||
|
@ -33,6 +33,8 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
|
||||
|
||||
/**
|
||||
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
|
||||
* The ROC curve is a plot of the TPR (true positive rate) against
|
||||
@ -66,6 +68,9 @@ public class AucRoc implements SoftClassificationMetric {
|
||||
|
||||
private static final String PERCENTILES = "percentiles";
|
||||
|
||||
private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
|
||||
private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
|
||||
|
||||
public static AucRoc fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
@ -118,30 +123,39 @@ public class AucRoc implements SoftClassificationMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
for (ClassInfo classInfo : classInfos) {
|
||||
AggregationBuilder percentilesForClassValueAgg = AggregationBuilders
|
||||
.filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery())
|
||||
AggregationBuilder percentilesForClassValueAgg =
|
||||
AggregationBuilders
|
||||
.filter(TRUE_AGG_NAME, actualIsTrueQuery(actualField))
|
||||
.subAggregation(
|
||||
AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles));
|
||||
AggregationBuilder percentilesForRestAgg = AggregationBuilders
|
||||
.filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery()))
|
||||
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
|
||||
AggregationBuilder percentilesForRestAgg =
|
||||
AggregationBuilders
|
||||
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
|
||||
.subAggregation(
|
||||
AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles));
|
||||
aggs.add(percentilesForClassValueAgg);
|
||||
aggs.add(percentilesForRestAgg);
|
||||
}
|
||||
return aggs;
|
||||
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
|
||||
return Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ClassInfo classInfo, Aggregations aggs) {
|
||||
result = evaluate(classInfo, aggs);
|
||||
public void process(Aggregations aggs) {
|
||||
Filter classAgg = aggs.get(TRUE_AGG_NAME);
|
||||
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
|
||||
double[] tpPercentiles =
|
||||
percentilesArray(
|
||||
classAgg.getAggregations().get(PERCENTILES),
|
||||
"[" + getName() + "] requires at least one actual_field to have the value [true]");
|
||||
double[] fpPercentiles =
|
||||
percentilesArray(
|
||||
restAgg.getAggregations().get(PERCENTILES),
|
||||
"[" + getName() + "] requires at least one actual_field to have a different value than [true]");
|
||||
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
|
||||
double aucRocScore = calculateAucScore(aucRocCurve);
|
||||
result = new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -149,26 +163,6 @@ public class AucRoc implements SoftClassificationMetric {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
private String evaluatedLabelAggName(ClassInfo classInfo) {
|
||||
return getName() + "_" + classInfo.getName();
|
||||
}
|
||||
|
||||
private String restLabelsAggName(ClassInfo classInfo) {
|
||||
return getName() + "_non_" + classInfo.getName();
|
||||
}
|
||||
|
||||
private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
|
||||
Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo));
|
||||
Filter restAgg = aggs.get(restLabelsAggName(classInfo));
|
||||
double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES),
|
||||
"[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
|
||||
double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES),
|
||||
"[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
|
||||
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
|
||||
double aucRocScore = calculateAucScore(aucRocCurve);
|
||||
return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
|
||||
}
|
||||
|
||||
private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) {
|
||||
double[] result = new double[99];
|
||||
percentiles.forEach(percentile -> {
|
||||
|
@ -5,7 +5,6 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@ -74,16 +67,7 @@ public class BinarySoftClassification implements Evaluation {
|
||||
@Nullable List<SoftClassificationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
|
||||
this.metrics = initMetrics(metrics);
|
||||
}
|
||||
|
||||
private static List<SoftClassificationMetric> initMetrics(@Nullable List<SoftClassificationMetric> parsedMetrics) {
|
||||
List<SoftClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics;
|
||||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName));
|
||||
return metrics;
|
||||
this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics);
|
||||
}
|
||||
|
||||
private static List<SoftClassificationMetric> defaultMetrics() {
|
||||
@ -100,6 +84,26 @@ public class BinarySoftClassification implements Evaluation {
|
||||
this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getActualField() {
|
||||
return actualField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getPredictedField() {
|
||||
return predictedProbabilityField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SoftClassificationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
@ -142,60 +146,4 @@ public class BinarySoftClassification implements Evaluation {
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedProbabilityField, metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SoftClassificationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
|
||||
SearchSourceBuilder searchSourceBuilder =
|
||||
newSearchSourceBuilder(Arrays.asList(actualField, predictedProbabilityField), userProvidedQueryBuilder);
|
||||
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
|
||||
for (SoftClassificationMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo));
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(SearchResponse searchResponse) {
|
||||
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField);
|
||||
}
|
||||
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
|
||||
for (SoftClassificationMetric metric : metrics) {
|
||||
metric.process(binaryClassInfo, searchResponse.getAggregations());
|
||||
}
|
||||
}
|
||||
|
||||
private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {
|
||||
|
||||
private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return String.valueOf(true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryBuilder matchingQuery() {
|
||||
return matchingQuery;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getProbabilityField() {
|
||||
return predictedProbabilityField;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
|
||||
}
|
||||
|
||||
public ConfusionMatrix(List<Double> at) {
|
||||
super(at.stream().mapToDouble(Double::doubleValue).toArray());
|
||||
super(at);
|
||||
}
|
||||
|
||||
public ConfusionMatrix(StreamInput in) throws IOException {
|
||||
@ -68,28 +68,29 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
|
||||
protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
for (ClassInfo classInfo : classInfos) {
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.TP));
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.FP));
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.TN));
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.FN));
|
||||
for (int i = 0; i < thresholds.length; i++) {
|
||||
double threshold = thresholds[i];
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP));
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TN));
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN));
|
||||
}
|
||||
return aggs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
long[] tp = new long[thresholds.length];
|
||||
long[] fp = new long[thresholds.length];
|
||||
long[] tn = new long[thresholds.length];
|
||||
long[] fn = new long[thresholds.length];
|
||||
for (int i = 0; i < thresholds.length; i++) {
|
||||
Filter tpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TP));
|
||||
Filter fpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FP));
|
||||
Filter tnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TN));
|
||||
Filter fnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FN));
|
||||
Filter tpAgg = aggs.get(aggName(thresholds[i], Condition.TP));
|
||||
Filter fpAgg = aggs.get(aggName(thresholds[i], Condition.FP));
|
||||
Filter tnAgg = aggs.get(aggName(thresholds[i], Condition.TN));
|
||||
Filter fnAgg = aggs.get(aggName(thresholds[i], Condition.FN));
|
||||
tp[i] = tpAgg.getDocCount();
|
||||
fp[i] = fpAgg.getDocCount();
|
||||
tn[i] = tnAgg.getDocCount();
|
||||
|
@ -35,7 +35,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
|
||||
}
|
||||
|
||||
public Precision(List<Double> at) {
|
||||
super(at.stream().mapToDouble(Double::doubleValue).toArray());
|
||||
super(at);
|
||||
}
|
||||
|
||||
public Precision(StreamInput in) throws IOException {
|
||||
@ -66,22 +66,23 @@ public class Precision extends AbstractConfusionMatrixMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
|
||||
protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
for (ClassInfo classInfo : classInfos) {
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.TP));
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.FP));
|
||||
for (int i = 0; i < thresholds.length; i++) {
|
||||
double threshold = thresholds[i];
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP));
|
||||
}
|
||||
return aggs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
double[] precisions = new double[thresholds.length];
|
||||
for (int i = 0; i < precisions.length; i++) {
|
||||
for (int i = 0; i < thresholds.length; i++) {
|
||||
double threshold = thresholds[i];
|
||||
Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
|
||||
Filter fpAgg = aggs.get(aggName(classInfo, threshold, Condition.FP));
|
||||
Filter tpAgg = aggs.get(aggName(threshold, Condition.TP));
|
||||
Filter fpAgg = aggs.get(aggName(threshold, Condition.FP));
|
||||
long tp = tpAgg.getDocCount();
|
||||
long fp = fpAgg.getDocCount();
|
||||
precisions[i] = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);
|
||||
|
@ -35,7 +35,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
|
||||
}
|
||||
|
||||
public Recall(List<Double> at) {
|
||||
super(at.stream().mapToDouble(Double::doubleValue).toArray());
|
||||
super(at);
|
||||
}
|
||||
|
||||
public Recall(StreamInput in) throws IOException {
|
||||
@ -66,22 +66,23 @@ public class Recall extends AbstractConfusionMatrixMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<AggregationBuilder> aggsAt(String actualField, List<ClassInfo> classInfos, double threshold) {
|
||||
protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
for (ClassInfo classInfo : classInfos) {
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.TP));
|
||||
aggs.add(buildAgg(classInfo, threshold, Condition.FN));
|
||||
for (int i = 0; i < thresholds.length; i++) {
|
||||
double threshold = thresholds[i];
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
|
||||
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN));
|
||||
}
|
||||
return aggs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
double[] recalls = new double[thresholds.length];
|
||||
for (int i = 0; i < recalls.length; i++) {
|
||||
for (int i = 0; i < thresholds.length; i++) {
|
||||
double threshold = thresholds[i];
|
||||
Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
|
||||
Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
|
||||
Filter tpAgg = aggs.get(aggName(threshold, Condition.TP));
|
||||
Filter fnAgg = aggs.get(aggName(threshold, Condition.FN));
|
||||
long tp = tpAgg.getDocCount();
|
||||
long fn = fnAgg.getDocCount();
|
||||
recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
|
||||
|
@ -5,49 +5,13 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface SoftClassificationMetric extends EvaluationMetric {
|
||||
|
||||
/**
|
||||
* The information of a specific class
|
||||
*/
|
||||
interface ClassInfo {
|
||||
|
||||
/**
|
||||
* Returns the class name
|
||||
*/
|
||||
String getName();
|
||||
|
||||
/**
|
||||
* Returns a query that matches documents of the class
|
||||
*/
|
||||
QueryBuilder matchingQuery();
|
||||
|
||||
/**
|
||||
* Returns the field that has the probability to be of the class
|
||||
*/
|
||||
String getProbabilityField();
|
||||
static QueryBuilder actualIsTrueQuery(String actualField) {
|
||||
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param actualField the field that stores the actual class
|
||||
* @param classInfos the information of each class to compute the metric for
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
* @param classInfo the class to calculate the metric for
|
||||
* @param aggs aggregations from {@link SearchResponse}
|
||||
*/
|
||||
void process(ClassInfo classInfo, Aggregations aggs);
|
||||
}
|
||||
|
@ -49,22 +49,19 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
|
||||
when(classInfo.getName()).thenReturn("foo");
|
||||
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createFilterAgg("confusion_matrix_foo_at_0.25_TP", 1L),
|
||||
createFilterAgg("confusion_matrix_foo_at_0.25_FP", 2L),
|
||||
createFilterAgg("confusion_matrix_foo_at_0.25_TN", 3L),
|
||||
createFilterAgg("confusion_matrix_foo_at_0.25_FN", 4L),
|
||||
createFilterAgg("confusion_matrix_foo_at_0.5_TP", 5L),
|
||||
createFilterAgg("confusion_matrix_foo_at_0.5_FP", 6L),
|
||||
createFilterAgg("confusion_matrix_foo_at_0.5_TN", 7L),
|
||||
createFilterAgg("confusion_matrix_foo_at_0.5_FN", 8L)
|
||||
createFilterAgg("confusion_matrix_at_0.25_TP", 1L),
|
||||
createFilterAgg("confusion_matrix_at_0.25_FP", 2L),
|
||||
createFilterAgg("confusion_matrix_at_0.25_TN", 3L),
|
||||
createFilterAgg("confusion_matrix_at_0.25_FN", 4L),
|
||||
createFilterAgg("confusion_matrix_at_0.5_TP", 5L),
|
||||
createFilterAgg("confusion_matrix_at_0.5_FP", 6L),
|
||||
createFilterAgg("confusion_matrix_at_0.5_TN", 7L),
|
||||
createFilterAgg("confusion_matrix_at_0.5_FN", 8L)
|
||||
));
|
||||
|
||||
ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5));
|
||||
EvaluationMetricResult result = confusionMatrix.evaluate(classInfo, aggs);
|
||||
EvaluationMetricResult result = confusionMatrix.evaluate(aggs);
|
||||
|
||||
String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
|
@ -49,36 +49,30 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
|
||||
when(classInfo.getName()).thenReturn("foo");
|
||||
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createFilterAgg("precision_foo_at_0.25_TP", 1L),
|
||||
createFilterAgg("precision_foo_at_0.25_FP", 4L),
|
||||
createFilterAgg("precision_foo_at_0.5_TP", 3L),
|
||||
createFilterAgg("precision_foo_at_0.5_FP", 1L),
|
||||
createFilterAgg("precision_foo_at_0.75_TP", 5L),
|
||||
createFilterAgg("precision_foo_at_0.75_FP", 0L)
|
||||
createFilterAgg("precision_at_0.25_TP", 1L),
|
||||
createFilterAgg("precision_at_0.25_FP", 4L),
|
||||
createFilterAgg("precision_at_0.5_TP", 3L),
|
||||
createFilterAgg("precision_at_0.5_FP", 1L),
|
||||
createFilterAgg("precision_at_0.75_TP", 5L),
|
||||
createFilterAgg("precision_at_0.75_FP", 0L)
|
||||
));
|
||||
|
||||
Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75));
|
||||
EvaluationMetricResult result = precision.evaluate(classInfo, aggs);
|
||||
EvaluationMetricResult result = precision.evaluate(aggs);
|
||||
|
||||
String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenZeroTpAndFp() {
|
||||
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
|
||||
when(classInfo.getName()).thenReturn("foo");
|
||||
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createFilterAgg("precision_foo_at_1.0_TP", 0L),
|
||||
createFilterAgg("precision_foo_at_1.0_FP", 0L)
|
||||
createFilterAgg("precision_at_1.0_TP", 0L),
|
||||
createFilterAgg("precision_at_1.0_FP", 0L)
|
||||
));
|
||||
|
||||
Precision precision = new Precision(Arrays.asList(1.0));
|
||||
EvaluationMetricResult result = precision.evaluate(classInfo, aggs);
|
||||
EvaluationMetricResult result = precision.evaluate(aggs);
|
||||
|
||||
String expected = "{\"1.0\":0.0}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
|
@ -49,36 +49,30 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
|
||||
when(classInfo.getName()).thenReturn("foo");
|
||||
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createFilterAgg("recall_foo_at_0.25_TP", 1L),
|
||||
createFilterAgg("recall_foo_at_0.25_FN", 4L),
|
||||
createFilterAgg("recall_foo_at_0.5_TP", 3L),
|
||||
createFilterAgg("recall_foo_at_0.5_FN", 1L),
|
||||
createFilterAgg("recall_foo_at_0.75_TP", 5L),
|
||||
createFilterAgg("recall_foo_at_0.75_FN", 0L)
|
||||
createFilterAgg("recall_at_0.25_TP", 1L),
|
||||
createFilterAgg("recall_at_0.25_FN", 4L),
|
||||
createFilterAgg("recall_at_0.5_TP", 3L),
|
||||
createFilterAgg("recall_at_0.5_FN", 1L),
|
||||
createFilterAgg("recall_at_0.75_TP", 5L),
|
||||
createFilterAgg("recall_at_0.75_FN", 0L)
|
||||
));
|
||||
|
||||
Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75));
|
||||
EvaluationMetricResult result = recall.evaluate(classInfo, aggs);
|
||||
EvaluationMetricResult result = recall.evaluate(aggs);
|
||||
|
||||
String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenZeroTpAndFp() {
|
||||
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
|
||||
when(classInfo.getName()).thenReturn("foo");
|
||||
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createFilterAgg("recall_foo_at_1.0_TP", 0L),
|
||||
createFilterAgg("recall_foo_at_1.0_FN", 0L)
|
||||
createFilterAgg("recall_at_1.0_TP", 0L),
|
||||
createFilterAgg("recall_at_1.0_FN", 0L)
|
||||
));
|
||||
|
||||
Recall recall = new Recall(Arrays.asList(1.0));
|
||||
EvaluationMetricResult result = recall.evaluate(classInfo, aggs);
|
||||
EvaluationMetricResult result = recall.evaluate(aggs);
|
||||
|
||||
String expected = "{\"1.0\":0.0}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
|
Loading…
x
Reference in New Issue
Block a user