[7.x] Remove ClassInfo interface and BinaryClassInfo class. (#49649) (#49681)

This commit is contained in:
Przemysław Witek 2019-11-28 21:46:46 +01:00 committed by GitHub
parent 496bb9e2ee
commit 1425e30b1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 259 additions and 404 deletions

View File

@ -6,15 +6,23 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
/**
@ -27,37 +35,67 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
*/
String getName();
/**
* Returns the field containing the actual value
*/
String getActualField();
/**
* Returns the field containing the predicted value
*/
String getPredictedField();
/**
* Returns the list of metrics to evaluate
* @return list of metrics to evaluate
*/
List<? extends EvaluationMetric> getMetrics();
/**
 * Validates and normalizes the metrics configured for this evaluation.
 * <p>
 * The result is always a fresh, mutable copy sorted by metric name. Copying in both branches means that neither the
 * caller's list nor the list handed out by the supplier is reordered, and that a supplier returning an unmodifiable
 * list cannot make the sort fail.
 * <p>
 * If the resulting list is empty, the exception built by {@code ExceptionsHelper.badRequestException} is thrown.
 *
 * @param parsedMetrics metrics provided by the caller, or {@code null} to fall back to the defaults
 * @param defaultMetricsSupplier supplies the metrics used when {@code parsedMetrics} is {@code null}
 * @return non-empty list of metrics sorted by {@link EvaluationMetric#getName()}
 */
default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parsedMetrics, Supplier<List<T>> defaultMetricsSupplier) {
    List<T> source = parsedMetrics == null ? defaultMetricsSupplier.get() : parsedMetrics;
    // Defensive copy so that sorting below never touches (or depends on the mutability of) the source list
    List<T> metrics = new ArrayList<>(source);
    if (metrics.isEmpty()) {
        throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
    }
    metrics.sort(Comparator.comparing(EvaluationMetric::getName));
    return metrics;
}
/**
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);
/**
* Builds the search that verifies existence of required fields and applies user-provided query
* @param requiredFields fields that must exist
* @param userProvidedQueryBuilder user-provided query
*/
default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
for (String requiredField : requiredFields) {
boolQuery.filter(QueryBuilders.existsQuery(requiredField));
default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery()
// Verify existence of required fields
.filter(QueryBuilders.existsQuery(getActualField()))
.filter(QueryBuilders.existsQuery(getPredictedField()))
// Apply user-provided query
.filter(userProvidedQueryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
List<AggregationBuilder> aggs = metric.aggs(getActualField(), getPredictedField());
aggs.forEach(searchSourceBuilder::aggregation);
}
boolQuery.filter(userProvidedQueryBuilder);
return new SearchSourceBuilder().size(0).query(boolQuery);
return searchSourceBuilder;
}
/**
* Processes {@link SearchResponse} from the search action
* @param searchResponse response from the search action
*/
void process(SearchResponse searchResponse);
default void process(SearchResponse searchResponse) {
    Objects.requireNonNull(searchResponse);
    // Zero hits means no document matched the search, so there is nothing for the metrics to process
    final long matchingDocs = searchResponse.getHits().getTotalHits().value;
    if (matchingDocs == 0) {
        throw ExceptionsHelper.badRequestException(
            "No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
    }
    // Hand the aggregation results over to every metric
    getMetrics().forEach(metric -> metric.process(searchResponse.getAggregations()));
}
/**
* @return true iff all the metrics have their results computed

View File

@ -5,9 +5,13 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import java.util.List;
import java.util.Optional;
/**
@ -20,6 +24,20 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
*/
String getName();
/**
* Builds the aggregations that collect the data required to compute the metric
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value (class name or probability)
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, String predictedField);
/**
* Processes given aggregations as a step towards computing result
* @param aggs aggregations from {@link SearchResponse}
*/
void process(Aggregations aggs);
/**
* Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@ -55,13 +48,13 @@ public class Classification implements Evaluation {
/**
* The field containing the actual value
* The value of this field is assumed to be numeric
* The value of this field is assumed to be categorical
*/
private final String actualField;
/**
* The field containing the predicted value
* The value of this field is assumed to be numeric
* The value of this field is assumed to be categorical
*/
private final String predictedField;
@ -73,7 +66,11 @@ public class Classification implements Evaluation {
public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.metrics = initMetrics(metrics);
this.metrics = initMetrics(metrics, Classification::defaultMetrics);
}
/**
 * Supplies the metrics used when the constructor receives {@code null} metrics.
 * @return a fixed-size list holding a single {@link MulticlassConfusionMatrix}
 */
private static List<ClassificationMetric> defaultMetrics() {
return Arrays.asList(new MulticlassConfusionMatrix());
}
public Classification(StreamInput in) throws IOException {
@ -82,52 +79,26 @@ public class Classification implements Evaluation {
this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
}
private static List<ClassificationMetric> initMetrics(@Nullable List<ClassificationMetric> parsedMetrics) {
List<ClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName));
return metrics;
}
private static List<ClassificationMetric> defaultMetrics() {
return Arrays.asList(new MulticlassConfusionMatrix());
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public String getActualField() {
return actualField;
}
@Override
public String getPredictedField() {
return predictedField;
}
@Override
public List<ClassificationMetric> getMetrics() {
return metrics;
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder =
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
for (ClassificationMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
aggs.forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
}
@Override
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
}
for (ClassificationMetric metric : metrics) {
metric.process(searchResponse.getAggregations());
}
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();

View File

@ -5,26 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import java.util.List;
public interface ClassificationMetric extends EvaluationMetric {
/**
* Builds the aggregation that collect required data to compute the metric
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, String predictedField);
/**
* Processes given aggregations as a step towards computing result
* @param aggs aggregations from {@link SearchResponse}
*/
void process(Aggregations aggs);
}

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@ -73,7 +66,11 @@ public class Regression implements Evaluation {
public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.metrics = initMetrics(metrics);
this.metrics = initMetrics(metrics, Regression::defaultMetrics);
}
/**
 * Supplies the metrics used when the constructor receives {@code null} metrics.
 * @return a fixed-size list holding a {@link MeanSquaredError} and an {@link RSquared}
 */
private static List<RegressionMetric> defaultMetrics() {
return Arrays.asList(new MeanSquaredError(), new RSquared());
}
public Regression(StreamInput in) throws IOException {
@ -82,52 +79,26 @@ public class Regression implements Evaluation {
this.metrics = in.readNamedWriteableList(RegressionMetric.class);
}
private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName));
return metrics;
}
private static List<RegressionMetric> defaultMetrics() {
return Arrays.asList(new MeanSquaredError(), new RSquared());
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public String getActualField() {
return actualField;
}
@Override
public String getPredictedField() {
return predictedField;
}
@Override
public List<RegressionMetric> getMetrics() {
return metrics;
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder =
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
for (RegressionMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
aggs.forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
}
@Override
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
}
for (RegressionMetric metric : metrics) {
metric.process(searchResponse.getAggregations());
}
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();

View File

@ -5,26 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import java.util.List;
public interface RegressionMetric extends EvaluationMetric {
/**
* Builds the aggregation that collect required data to compute the metric
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, String predictedField);
/**
* Processes given aggregations as a step towards computing result
* @param aggs aggregations from {@link SearchResponse}
*/
void process(Aggregations aggs);
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
@ -18,11 +19,12 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
public static final ParseField AT = new ParseField("at");
@ -30,8 +32,8 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
protected final double[] thresholds;
private EvaluationMetricResult result;
protected AbstractConfusionMatrixMetric(double[] thresholds) {
this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
protected AbstractConfusionMatrixMetric(List<Double> at) {
this.thresholds = ExceptionsHelper.requireNonNull(at, AT).stream().mapToDouble(Double::doubleValue).toArray();
if (thresholds.length == 0) {
throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
}
@ -61,20 +63,16 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
}
@Override
public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
if (result != null) {
return Collections.emptyList();
}
List<AggregationBuilder> aggs = new ArrayList<>();
for (double threshold : thresholds) {
aggs.addAll(aggsAt(actualField, classInfos, threshold));
}
return aggs;
return aggsAt(actualField, predictedProbabilityField);
}
@Override
public void process(ClassInfo classInfo, Aggregations aggs) {
result = evaluate(classInfo, aggs);
public void process(Aggregations aggs) {
result = evaluate(aggs);
}
@Override
@ -82,40 +80,43 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
return Optional.ofNullable(result);
}
protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
protected abstract List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField);
protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
protected abstract EvaluationMetricResult evaluate(Aggregations aggs);
protected enum Condition {
TP, FP, TN, FN;
}
enum Condition {
TP(true, true),
FP(false, true),
TN(false, false),
FN(true, false);
protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
}
final boolean actual;
final boolean predicted;
protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
switch (condition) {
case TP:
boolQuery.must(classInfo.matchingQuery());
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
break;
case FP:
boolQuery.mustNot(classInfo.matchingQuery());
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
break;
case TN:
boolQuery.mustNot(classInfo.matchingQuery());
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
break;
case FN:
boolQuery.must(classInfo.matchingQuery());
boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
break;
default:
throw new IllegalArgumentException("Unknown enum value: " + condition);
Condition(boolean actual, boolean predicted) {
this.actual = actual;
this.predicted = predicted;
}
return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery);
}
/**
 * Builds the name of the filter aggregation for the given threshold and condition.
 * The name has the form {@code <metric name>_at_<threshold>_<condition>}.
 */
protected String aggName(double threshold, Condition condition) {
    return String.join("_", getName(), "at", String.valueOf(threshold), condition.name());
}
/**
 * Builds the filter aggregation matching documents that fall into the given confusion-matrix condition
 * at the given threshold.
 *
 * @param actualField the field that stores the actual value
 * @param predictedProbabilityField the field that stores the predicted probability
 * @param threshold probability at or above which a document counts as predicted positive
 * @param condition which combination of actual/predicted outcome the aggregation should match
 * @return filter aggregation named by {@code aggName(threshold, condition)}
 */
protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) {
    QueryBuilder actualPositive = actualIsTrueQuery(actualField);
    QueryBuilder predictedPositive = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold);

    BoolQueryBuilder conditionQuery = QueryBuilders.boolQuery();
    // Actual side: require the "actual is true" query for TP/FN, exclude it for FP/TN
    if (condition.actual) {
        conditionQuery.must(actualPositive);
    } else {
        conditionQuery.mustNot(actualPositive);
    }
    // Predicted side: require probability >= threshold for TP/FP, exclude it for TN/FN
    if (condition.predicted) {
        conditionQuery.must(predictedPositive);
    } else {
        conditionQuery.mustNot(predictedPositive);
    }
    return AggregationBuilders.filter(aggName(threshold, condition), conditionQuery);
}
}

View File

@ -33,6 +33,8 @@ import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
/**
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
* The ROC curve is a plot of the TPR (true positive rate) against
@ -66,6 +68,9 @@ public class AucRoc implements SoftClassificationMetric {
private static final String PERCENTILES = "percentiles";
private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
public static AucRoc fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
@ -118,30 +123,39 @@ public class AucRoc implements SoftClassificationMetric {
}
@Override
public List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
public List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
if (result != null) {
return Collections.emptyList();
}
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
List<AggregationBuilder> aggs = new ArrayList<>();
for (ClassInfo classInfo : classInfos) {
AggregationBuilder percentilesForClassValueAgg = AggregationBuilders
.filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery())
AggregationBuilder percentilesForClassValueAgg =
AggregationBuilders
.filter(TRUE_AGG_NAME, actualIsTrueQuery(actualField))
.subAggregation(
AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles));
AggregationBuilder percentilesForRestAgg = AggregationBuilders
.filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery()))
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
AggregationBuilder percentilesForRestAgg =
AggregationBuilders
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
.subAggregation(
AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles));
aggs.add(percentilesForClassValueAgg);
aggs.add(percentilesForRestAgg);
}
return aggs;
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
return Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg);
}
@Override
public void process(ClassInfo classInfo, Aggregations aggs) {
result = evaluate(classInfo, aggs);
public void process(Aggregations aggs) {
    // Filter buckets requested by aggs(): documents whose actual value is true, and all remaining documents
    Filter trueDocs = aggs.get(TRUE_AGG_NAME);
    Filter nonTrueDocs = aggs.get(NON_TRUE_AGG_NAME);

    // Probability percentiles within each bucket; the message is passed to percentilesArray as its error text
    Percentiles truePercentiles = trueDocs.getAggregations().get(PERCENTILES);
    double[] tpPercentiles =
        percentilesArray(truePercentiles, "[" + getName() + "] requires at least one actual_field to have the value [true]");
    Percentiles nonTruePercentiles = nonTrueDocs.getAggregations().get(PERCENTILES);
    double[] fpPercentiles =
        percentilesArray(
            nonTruePercentiles, "[" + getName() + "] requires at least one actual_field to have a different value than [true]");

    List<AucRocPoint> curve = buildAucRocCurve(tpPercentiles, fpPercentiles);
    double score = calculateAucScore(curve);
    // The curve is only included in the result when it was requested
    result = new Result(score, includeCurve ? curve : Collections.emptyList());
}
@Override
@ -149,26 +163,6 @@ public class AucRoc implements SoftClassificationMetric {
return Optional.ofNullable(result);
}
private String evaluatedLabelAggName(ClassInfo classInfo) {
return getName() + "_" + classInfo.getName();
}
private String restLabelsAggName(ClassInfo classInfo) {
return getName() + "_non_" + classInfo.getName();
}
private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo));
Filter restAgg = aggs.get(restLabelsAggName(classInfo));
double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES),
"[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES),
"[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
}
private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) {
double[] result = new double[99];
percentiles.forEach(percentile -> {

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
@ -74,16 +67,7 @@ public class BinarySoftClassification implements Evaluation {
@Nullable List<SoftClassificationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
this.metrics = initMetrics(metrics);
}
private static List<SoftClassificationMetric> initMetrics(@Nullable List<SoftClassificationMetric> parsedMetrics) {
List<SoftClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics;
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName));
return metrics;
this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics);
}
private static List<SoftClassificationMetric> defaultMetrics() {
@ -100,6 +84,26 @@ public class BinarySoftClassification implements Evaluation {
this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class);
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public String getActualField() {
return actualField;
}
@Override
public String getPredictedField() {
return predictedProbabilityField;
}
@Override
public List<SoftClassificationMetric> getMetrics() {
return metrics;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -142,60 +146,4 @@ public class BinarySoftClassification implements Evaluation {
public int hashCode() {
return Objects.hash(actualField, predictedProbabilityField, metrics);
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public List<SoftClassificationMetric> getMetrics() {
return metrics;
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder =
newSearchSourceBuilder(Arrays.asList(actualField, predictedProbabilityField), userProvidedQueryBuilder);
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
for (SoftClassificationMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo));
aggs.forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
}
@Override
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException(
"No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField);
}
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
for (SoftClassificationMetric metric : metrics) {
metric.process(binaryClassInfo, searchResponse.getAggregations());
}
}
private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {
private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
@Override
public String getName() {
return String.valueOf(true);
}
@Override
public QueryBuilder matchingQuery() {
return matchingQuery;
}
@Override
public String getProbabilityField() {
return predictedProbabilityField;
}
}
}

View File

@ -37,7 +37,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
}
public ConfusionMatrix(List<Double> at) {
super(at.stream().mapToDouble(Double::doubleValue).toArray());
super(at);
}
public ConfusionMatrix(StreamInput in) throws IOException {
@ -68,28 +68,29 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
}
@Override
protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
List<AggregationBuilder> aggs = new ArrayList<>();
for (ClassInfo classInfo : classInfos) {
aggs.add(buildAgg(classInfo, threshold, Condition.TP));
aggs.add(buildAgg(classInfo, threshold, Condition.FP));
aggs.add(buildAgg(classInfo, threshold, Condition.TN));
aggs.add(buildAgg(classInfo, threshold, Condition.FN));
for (int i = 0; i < thresholds.length; i++) {
double threshold = thresholds[i];
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP));
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TN));
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN));
}
return aggs;
}
@Override
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
public EvaluationMetricResult evaluate(Aggregations aggs) {
long[] tp = new long[thresholds.length];
long[] fp = new long[thresholds.length];
long[] tn = new long[thresholds.length];
long[] fn = new long[thresholds.length];
for (int i = 0; i < thresholds.length; i++) {
Filter tpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TP));
Filter fpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FP));
Filter tnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TN));
Filter fnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FN));
Filter tpAgg = aggs.get(aggName(thresholds[i], Condition.TP));
Filter fpAgg = aggs.get(aggName(thresholds[i], Condition.FP));
Filter tnAgg = aggs.get(aggName(thresholds[i], Condition.TN));
Filter fnAgg = aggs.get(aggName(thresholds[i], Condition.FN));
tp[i] = tpAgg.getDocCount();
fp[i] = fpAgg.getDocCount();
tn[i] = tnAgg.getDocCount();

View File

@ -35,7 +35,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
}
/**
 * Creates the metric from the list of thresholds it is evaluated at.
 *
 * NOTE(review): two {@code super(...)} calls appear because this is a diff hunk without +/- markers.
 * The first unboxes the list into a {@code double[]} before delegating; the second hands the list
 * to the superclass as-is. Presumably the first is the removed line and the second its replacement.
 *
 * @param at thresholds for this metric
 */
public Precision(List<Double> at) {
super(at.stream().mapToDouble(Double::doubleValue).toArray());
super(at);
}
public Precision(StreamInput in) throws IOException {
@ -66,22 +66,23 @@ public class Precision extends AbstractConfusionMatrixMetric {
}
/**
 * Collects the aggregations needed to compute precision: one {@code buildAgg(...)} result for
 * {@code Condition.TP} and one for {@code Condition.FP}.
 *
 * NOTE(review): this span is a diff hunk whose +/- markers were stripped, so two versions of the
 * method are interleaved. The lines that mention {@code ClassInfo} are presumably the pre-change
 * version (the commit removes {@code ClassInfo}); the lines that iterate over {@code thresholds}
 * are presumably the post-change version.
 *
 * @param actualField passed through to {@code buildAgg} (newer signature)
 * @param predictedProbabilityField passed through to {@code buildAgg} (newer signature)
 * @return the list of aggregation builders, in the order they were added
 */
@Override
protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
List<AggregationBuilder> aggs = new ArrayList<>();
// ClassInfo-based variant: a TP/FP pair per class, at the single threshold parameter.
for (ClassInfo classInfo : classInfos) {
aggs.add(buildAgg(classInfo, threshold, Condition.TP));
aggs.add(buildAgg(classInfo, threshold, Condition.FP));
// Field-based variant: a TP/FP pair per entry of the thresholds array.
for (int i = 0; i < thresholds.length; i++) {
double threshold = thresholds[i];
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP));
}
return aggs;
}
@Override
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
public EvaluationMetricResult evaluate(Aggregations aggs) {
double[] precisions = new double[thresholds.length];
for (int i = 0; i < precisions.length; i++) {
for (int i = 0; i < thresholds.length; i++) {
double threshold = thresholds[i];
Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
Filter fpAgg = aggs.get(aggName(classInfo, threshold, Condition.FP));
Filter tpAgg = aggs.get(aggName(threshold, Condition.TP));
Filter fpAgg = aggs.get(aggName(threshold, Condition.FP));
long tp = tpAgg.getDocCount();
long fp = fpAgg.getDocCount();
precisions[i] = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);

View File

@ -35,7 +35,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
}
/**
 * Creates the metric from the list of thresholds it is evaluated at.
 *
 * NOTE(review): two {@code super(...)} calls appear because this is a diff hunk without +/- markers.
 * The first unboxes the list into a {@code double[]} before delegating; the second hands the list
 * to the superclass as-is. Presumably the first is the removed line and the second its replacement.
 *
 * @param at thresholds for this metric
 */
public Recall(List<Double> at) {
super(at.stream().mapToDouble(Double::doubleValue).toArray());
super(at);
}
public Recall(StreamInput in) throws IOException {
@ -66,22 +66,23 @@ public class Recall extends AbstractConfusionMatrixMetric {
}
/**
 * Collects the aggregations needed to compute recall: one {@code buildAgg(...)} result for
 * {@code Condition.TP} and one for {@code Condition.FN}.
 *
 * NOTE(review): this span is a diff hunk whose +/- markers were stripped, so two versions of the
 * method are interleaved. The lines that mention {@code ClassInfo} are presumably the pre-change
 * version (the commit removes {@code ClassInfo}); the lines that iterate over {@code thresholds}
 * are presumably the post-change version.
 *
 * @param actualField passed through to {@code buildAgg} (newer signature)
 * @param predictedProbabilityField passed through to {@code buildAgg} (newer signature)
 * @return the list of aggregation builders, in the order they were added
 */
@Override
protected List<AggregationBuilder> aggsAt(String actualField, List<ClassInfo> classInfos, double threshold) {
protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
List<AggregationBuilder> aggs = new ArrayList<>();
// ClassInfo-based variant: a TP/FN pair per class, at the single threshold parameter.
for (ClassInfo classInfo : classInfos) {
aggs.add(buildAgg(classInfo, threshold, Condition.TP));
aggs.add(buildAgg(classInfo, threshold, Condition.FN));
// Field-based variant: a TP/FN pair per entry of the thresholds array.
for (int i = 0; i < thresholds.length; i++) {
double threshold = thresholds[i];
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN));
}
return aggs;
}
@Override
public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
public EvaluationMetricResult evaluate(Aggregations aggs) {
double[] recalls = new double[thresholds.length];
for (int i = 0; i < recalls.length; i++) {
for (int i = 0; i < thresholds.length; i++) {
double threshold = thresholds[i];
Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
Filter tpAgg = aggs.get(aggName(threshold, Condition.TP));
Filter fnAgg = aggs.get(aggName(threshold, Condition.FN));
long tp = tpAgg.getDocCount();
long fn = fnAgg.getDocCount();
recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);

View File

@ -5,49 +5,13 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import java.util.List;
/**
* An {@link EvaluationMetric} in the soft classification evaluation package.
*
* NOTE(review): this span is a diff hunk whose +/- markers were stripped. The nested
* {@code ClassInfo} interface and the {@code aggs}/{@code process} declarations that take a
* {@code ClassInfo} are presumably the removed lines (the commit removes {@code ClassInfo});
* {@code actualIsTrueQuery} is presumably the added one.
*/
public interface SoftClassificationMetric extends EvaluationMetric {
/**
* The information of a specific class
*/
interface ClassInfo {
/**
* Returns the class name
*/
String getName();
/**
* Returns a query that matches documents of the class
*/
QueryBuilder matchingQuery();
/**
* Returns the field that has the probability to be of the class
*/
String getProbabilityField();
/**
* Builds a query-string query whose text is {@code actualField + ": (1 OR true)"}.
* NOTE(review): presumably selects documents whose actual value is 1 or true; confirm against
* the query_string syntax.
* @param actualField name of the field the query is written against
* @return the query-string query builder
*/
static QueryBuilder actualIsTrueQuery(String actualField) {
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
}
/**
* Builds the aggregations that collect the data required to compute the metric
* @param actualField the field that stores the actual class
* @param classInfos the information of each class to compute the metric for
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos);
/**
* Processes given aggregations as a step towards computing result
* @param classInfo the class to calculate the metric for
* @param aggs aggregations from {@link SearchResponse}
*/
void process(ClassInfo classInfo, Aggregations aggs);
}

View File

@ -49,22 +49,19 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
}
public void testEvaluate() {
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
when(classInfo.getName()).thenReturn("foo");
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("confusion_matrix_foo_at_0.25_TP", 1L),
createFilterAgg("confusion_matrix_foo_at_0.25_FP", 2L),
createFilterAgg("confusion_matrix_foo_at_0.25_TN", 3L),
createFilterAgg("confusion_matrix_foo_at_0.25_FN", 4L),
createFilterAgg("confusion_matrix_foo_at_0.5_TP", 5L),
createFilterAgg("confusion_matrix_foo_at_0.5_FP", 6L),
createFilterAgg("confusion_matrix_foo_at_0.5_TN", 7L),
createFilterAgg("confusion_matrix_foo_at_0.5_FN", 8L)
createFilterAgg("confusion_matrix_at_0.25_TP", 1L),
createFilterAgg("confusion_matrix_at_0.25_FP", 2L),
createFilterAgg("confusion_matrix_at_0.25_TN", 3L),
createFilterAgg("confusion_matrix_at_0.25_FN", 4L),
createFilterAgg("confusion_matrix_at_0.5_TP", 5L),
createFilterAgg("confusion_matrix_at_0.5_FP", 6L),
createFilterAgg("confusion_matrix_at_0.5_TN", 7L),
createFilterAgg("confusion_matrix_at_0.5_FN", 8L)
));
ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5));
EvaluationMetricResult result = confusionMatrix.evaluate(classInfo, aggs);
EvaluationMetricResult result = confusionMatrix.evaluate(aggs);
String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}";
assertThat(Strings.toString(result), equalTo(expected));

View File

@ -49,36 +49,30 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
}
/**
 * Feeds hand-built TP/FP filter aggregations for thresholds 0.25, 0.5 and 0.75 into
 * {@code Precision.evaluate} and checks the rendered result string.
 *
 * NOTE(review): this span is a diff hunk whose +/- markers were stripped. The mocked
 * {@code ClassInfo}, the {@code precision_foo_at_*} aggregation names and the two-argument
 * {@code evaluate} call are presumably the removed lines; the {@code precision_at_*} names and the
 * one-argument {@code evaluate} call are presumably their replacements.
 */
public void testEvaluate() {
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
when(classInfo.getName()).thenReturn("foo");
Aggregations aggs = new Aggregations(Arrays.asList(
// Aggregation names that embed the class name "foo".
createFilterAgg("precision_foo_at_0.25_TP", 1L),
createFilterAgg("precision_foo_at_0.25_FP", 4L),
createFilterAgg("precision_foo_at_0.5_TP", 3L),
createFilterAgg("precision_foo_at_0.5_FP", 1L),
createFilterAgg("precision_foo_at_0.75_TP", 5L),
createFilterAgg("precision_foo_at_0.75_FP", 0L)
// Aggregation names without a class name; the doc counts are the same as above.
createFilterAgg("precision_at_0.25_TP", 1L),
createFilterAgg("precision_at_0.25_FP", 4L),
createFilterAgg("precision_at_0.5_TP", 3L),
createFilterAgg("precision_at_0.5_FP", 1L),
createFilterAgg("precision_at_0.75_TP", 5L),
createFilterAgg("precision_at_0.75_FP", 0L)
));
Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75));
EvaluationMetricResult result = precision.evaluate(classInfo, aggs);
EvaluationMetricResult result = precision.evaluate(aggs);
// tp / (tp + fp) per threshold: 1/5 = 0.2, 3/4 = 0.75, 5/5 = 1.0
String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}";
assertThat(Strings.toString(result), equalTo(expected));
}
public void testEvaluate_GivenZeroTpAndFp() {
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
when(classInfo.getName()).thenReturn("foo");
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("precision_foo_at_1.0_TP", 0L),
createFilterAgg("precision_foo_at_1.0_FP", 0L)
createFilterAgg("precision_at_1.0_TP", 0L),
createFilterAgg("precision_at_1.0_FP", 0L)
));
Precision precision = new Precision(Arrays.asList(1.0));
EvaluationMetricResult result = precision.evaluate(classInfo, aggs);
EvaluationMetricResult result = precision.evaluate(aggs);
String expected = "{\"1.0\":0.0}";
assertThat(Strings.toString(result), equalTo(expected));

View File

@ -49,36 +49,30 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
}
/**
 * Feeds hand-built TP/FN filter aggregations for thresholds 0.25, 0.5 and 0.75 into
 * {@code Recall.evaluate} and checks the rendered result string.
 *
 * NOTE(review): this span is a diff hunk whose +/- markers were stripped. The mocked
 * {@code ClassInfo}, the {@code recall_foo_at_*} aggregation names and the two-argument
 * {@code evaluate} call are presumably the removed lines; the {@code recall_at_*} names and the
 * one-argument {@code evaluate} call are presumably their replacements.
 */
public void testEvaluate() {
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
when(classInfo.getName()).thenReturn("foo");
Aggregations aggs = new Aggregations(Arrays.asList(
// Aggregation names that embed the class name "foo".
createFilterAgg("recall_foo_at_0.25_TP", 1L),
createFilterAgg("recall_foo_at_0.25_FN", 4L),
createFilterAgg("recall_foo_at_0.5_TP", 3L),
createFilterAgg("recall_foo_at_0.5_FN", 1L),
createFilterAgg("recall_foo_at_0.75_TP", 5L),
createFilterAgg("recall_foo_at_0.75_FN", 0L)
// Aggregation names without a class name; the doc counts are the same as above.
createFilterAgg("recall_at_0.25_TP", 1L),
createFilterAgg("recall_at_0.25_FN", 4L),
createFilterAgg("recall_at_0.5_TP", 3L),
createFilterAgg("recall_at_0.5_FN", 1L),
createFilterAgg("recall_at_0.75_TP", 5L),
createFilterAgg("recall_at_0.75_FN", 0L)
));
Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75));
EvaluationMetricResult result = recall.evaluate(classInfo, aggs);
EvaluationMetricResult result = recall.evaluate(aggs);
// tp / (tp + fn) per threshold: 1/5 = 0.2, 3/4 = 0.75, 5/5 = 1.0
String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}";
assertThat(Strings.toString(result), equalTo(expected));
}
public void testEvaluate_GivenZeroTpAndFp() {
SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
when(classInfo.getName()).thenReturn("foo");
Aggregations aggs = new Aggregations(Arrays.asList(
createFilterAgg("recall_foo_at_1.0_TP", 0L),
createFilterAgg("recall_foo_at_1.0_FN", 0L)
createFilterAgg("recall_at_1.0_TP", 0L),
createFilterAgg("recall_at_1.0_FN", 0L)
));
Recall recall = new Recall(Arrays.asList(1.0));
EvaluationMetricResult result = recall.evaluate(classInfo, aggs);
EvaluationMetricResult result = recall.evaluate(aggs);
String expected = "{\"1.0\":0.0}";
assertThat(Strings.toString(result), equalTo(expected));