commit 4366d58564 (parent 179fe9cc0e)
@@ -88,7 +88,7 @@ the probability that each document is an outlier.
 `auc_roc`:::
 (Optional, object) The AUC ROC (area under the curve of the receiver
 operating characteristic) score and optionally the curve. Default value is
-{"includes_curve": false}.
+{"include_curve": false}.
 
 `confusion_matrix`:::
 (Optional, object) Set the different thresholds of the {olscore} at where
@@ -153,9 +153,14 @@ belongs.
 The data type of this field must be categorical.
 
 `predicted_field`::
-(Required, string) The field in the `index` that contains the predicted value,
+(Optional, string) The field in the `index` which contains the predicted value,
 in other words the results of the {classanalysis}.
 
+`top_classes_field`::
+(Optional, string) The field of the `index` which is an array of documents
+of the form `{ "class_name": XXX, "class_probability": YYY }`.
+This field must be defined as `nested` in the mappings.
+
 `metrics`::
 (Optional, object) Specifies the metrics that are used for the evaluation.
 Available metrics:
@@ -163,6 +168,24 @@ belongs.
 `accuracy`:::
 (Optional, object) Accuracy of predictions (per-class and overall).
 
+`auc_roc`:::
+(Optional, object) The AUC ROC (area under the curve of the receiver
+operating characteristic) score and optionally the curve.
+It is calculated for a specific class (provided as "class_name")
+treated as positive.
+
+`class_name`::::
+(Required, string) Name of the only class that will be treated as
+positive during AUC ROC calculation. Other classes will be treated as
+negative ("one-vs-all" strategy). Documents which do not have `class_name`
+in the list of their top classes will not be taken into account for evaluation.
+The number of documents taken into account is returned in the evaluation result
+(`auc_roc.doc_count` field).
+
+`include_curve`::::
+(Optional, boolean) Whether or not the curve should be returned in
+addition to the score. Default value is false.
+
 `multiclass_confusion_matrix`:::
 (Optional, object) Multiclass confusion matrix.
 
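To make the documented classification parameters concrete, here is a minimal illustrative sketch (not part of this commit) that builds the same configuration with the server-side classes this diff introduces. The `Classification` and `AucRoc` constructors are the ones shown later in this diff; the wrapper class name, the index field name "animal_class" and the class value "dog" are made up for illustration.

import java.util.Arrays;
import java.util.List;

import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;

public class ClassificationAucRocConfigSketch {
    public static void main(String[] args) {
        // auc_roc treats "dog" as the positive class (one-vs-all) and returns the curve.
        List<EvaluationMetric> metrics = Arrays.asList(new AucRoc(true, "dog"));
        // predicted_field and top_classes_field are optional; leaving top_classes_field
        // null falls back to the default "ml.top_classes".
        Classification evaluation = new Classification("animal_class", null, null, metrics);
    }
}

These constructor arguments map one-to-one onto the request parameters documented above: actual_field, predicted_field, top_classes_field, and metrics.auc_roc with class_name and include_curve.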
@ -394,7 +394,16 @@ public class Classification implements DataFrameAnalysis {
|
|||
return additionalProperties;
|
||||
}
|
||||
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
|
||||
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
|
||||
|
||||
Map<String, Object> topClassesProperties = new HashMap<>();
|
||||
topClassesProperties.put("class_name", dependentVariableMapping);
|
||||
topClassesProperties.put("class_probability", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
|
||||
|
||||
Map<String, Object> topClassesMapping = new HashMap<>();
|
||||
topClassesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
|
||||
topClassesMapping.put("properties", topClassesProperties);
|
||||
|
||||
additionalProperties.put(resultsFieldName + ".top_classes", topClassesMapping);
|
||||
return additionalProperties;
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.apache.lucene.search.join.ScoreMode;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
|
@ -21,11 +22,16 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static java.util.stream.Collectors.joining;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static java.util.stream.Collectors.toSet;
|
||||
|
||||
/**
|
||||
* Defines an evaluation
|
||||
|
@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
String getName();
|
||||
|
||||
/**
|
||||
* Returns the field containing the actual value
|
||||
* Returns the collection of fields required by evaluation
|
||||
*/
|
||||
String getActualField();
|
||||
|
||||
/**
|
||||
* Returns the field containing the predicted value
|
||||
*/
|
||||
String getPredictedField();
|
||||
EvaluationFields getFields();
|
||||
|
||||
/**
|
||||
* Returns the list of metrics to evaluate
|
||||
|
@ -59,27 +60,74 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
|
||||
checkRequiredFieldsAreSet(metrics);
|
||||
return metrics;
|
||||
}
|
||||
|
||||
default <T extends EvaluationMetric> void checkRequiredFieldsAreSet(List<T> metrics) {
|
||||
assert (metrics == null || metrics.isEmpty()) == false;
|
||||
for (Tuple<String, String> requiredField : getFields().listPotentiallyRequiredFields()) {
|
||||
String fieldDescriptor = requiredField.v1();
|
||||
String field = requiredField.v2();
|
||||
if (field == null) {
|
||||
String metricNamesString =
|
||||
metrics.stream()
|
||||
.filter(m -> m.getRequiredFields().contains(fieldDescriptor))
|
||||
.map(EvaluationMetric::getName)
|
||||
.collect(joining(", "));
|
||||
if (metricNamesString.isEmpty() == false) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] must define [{}] as required by the following metrics [{}]",
|
||||
getName(), fieldDescriptor, metricNamesString);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the search required to collect data to compute the evaluation result
|
||||
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
|
||||
*/
|
||||
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
|
||||
Objects.requireNonNull(userProvidedQueryBuilder);
|
||||
BoolQueryBuilder boolQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
// Verify existence of required fields
|
||||
.filter(QueryBuilders.existsQuery(getActualField()))
|
||||
.filter(QueryBuilders.existsQuery(getPredictedField()))
|
||||
// Apply user-provided query
|
||||
.filter(userProvidedQueryBuilder);
|
||||
Set<String> requiredFields = new HashSet<>(getRequiredFields());
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
|
||||
if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) {
|
||||
// Verify existence of the actual field if required
|
||||
boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField()));
|
||||
}
|
||||
if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) {
|
||||
// Verify existence of the predicted field if required
|
||||
boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField()));
|
||||
}
|
||||
if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) {
|
||||
assert getFields().getTopClassesField() != null;
|
||||
// Verify existence of the predicted class name field if required
|
||||
QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField());
|
||||
boolQuery.filter(
|
||||
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None)
|
||||
.ignoreUnmapped(true));
|
||||
}
|
||||
if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) {
|
||||
// Verify existence of the predicted probability field if required
|
||||
QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField());
|
||||
// predicted probability field may be either nested (just like in case of classification evaluation) or non-nested (just like
|
||||
// in case of outlier detection evaluation). Here we support both modes.
|
||||
if (getFields().isPredictedProbabilityFieldNested()) {
|
||||
assert getFields().getTopClassesField() != null;
|
||||
boolQuery.filter(
|
||||
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None)
|
||||
.ignoreUnmapped(true));
|
||||
} else {
|
||||
boolQuery.filter(predictedProbabilityFieldExistsQuery);
|
||||
}
|
||||
}
|
||||
// Apply user-provided query
|
||||
boolQuery.filter(userProvidedQueryBuilder);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
for (EvaluationMetric metric : getMetrics()) {
|
||||
// Fetch aggregations requested by individual metrics
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
|
||||
metric.aggs(parameters, getActualField(), getPredictedField());
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(parameters, getFields());
|
||||
aggs.v1().forEach(searchSourceBuilder::aggregation);
|
||||
aggs.v2().forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
|
@ -93,14 +141,31 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
default void process(SearchResponse searchResponse) {
|
||||
Objects.requireNonNull(searchResponse);
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
|
||||
String requiredFieldsString = String.join(", ", getRequiredFields());
|
||||
throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString);
|
||||
}
|
||||
for (EvaluationMetric metric : getMetrics()) {
|
||||
metric.process(searchResponse.getAggregations());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return list of fields which are required by at least one of the metrics
|
||||
*/
|
||||
default List<String> getRequiredFields() {
|
||||
Set<String> requiredFieldDescriptors =
|
||||
getMetrics().stream()
|
||||
.map(EvaluationMetric::getRequiredFields)
|
||||
.flatMap(Set::stream)
|
||||
.collect(toSet());
|
||||
List<String> requiredFields =
|
||||
getFields().listPotentiallyRequiredFields().stream()
|
||||
.filter(f -> requiredFieldDescriptors.contains(f.v1()))
|
||||
.map(Tuple::v2)
|
||||
.collect(toList());
|
||||
return requiredFields;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return true iff all the metrics have their results computed
|
||||
*/
|
||||
|
@ -117,6 +182,6 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
|||
.map(EvaluationMetric::getResult)
|
||||
.filter(Optional::isPresent)
|
||||
.map(Optional::get)
|
||||
.collect(Collectors.toList());
|
||||
.collect(toList());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Encapsulates fields needed by evaluation.
|
||||
*/
|
||||
public final class EvaluationFields {
|
||||
|
||||
public static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
public static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
public static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
|
||||
public static final ParseField PREDICTED_CLASS_FIELD = new ParseField("predicted_class_field");
|
||||
public static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");
|
||||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
*/
|
||||
private final String predictedField;
|
||||
|
||||
/**
|
||||
* The field containing the array of top classes
|
||||
*/
|
||||
private final String topClassesField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted class name value
|
||||
*/
|
||||
private final String predictedClassField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted probability value in [0.0, 1.0]
|
||||
*/
|
||||
private final String predictedProbabilityField;
|
||||
|
||||
/**
|
||||
* Whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
|
||||
*/
|
||||
private final boolean predictedProbabilityFieldNested;
|
||||
|
||||
public EvaluationFields(@Nullable String actualField,
|
||||
@Nullable String predictedField,
|
||||
@Nullable String topClassesField,
|
||||
@Nullable String predictedClassField,
|
||||
@Nullable String predictedProbabilityField,
|
||||
boolean predictedProbabilityFieldNested) {
|
||||
|
||||
this.actualField = actualField;
|
||||
this.predictedField = predictedField;
|
||||
this.topClassesField = topClassesField;
|
||||
this.predictedClassField = predictedClassField;
|
||||
this.predictedProbabilityField = predictedProbabilityField;
|
||||
this.predictedProbabilityFieldNested = predictedProbabilityFieldNested;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the field containing the actual value
|
||||
*/
|
||||
public String getActualField() {
|
||||
return actualField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the field containing the predicted value
|
||||
*/
|
||||
public String getPredictedField() {
|
||||
return predictedField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the field containing the array of top classes
|
||||
*/
|
||||
public String getTopClassesField() {
|
||||
return topClassesField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the field containing the predicted class name value
|
||||
*/
|
||||
public String getPredictedClassField() {
|
||||
return predictedClassField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the field containing the predicted probability value in [0.0, 1.0]
|
||||
*/
|
||||
public String getPredictedProbabilityField() {
|
||||
return predictedProbabilityField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
|
||||
*/
|
||||
public boolean isPredictedProbabilityFieldNested() {
|
||||
return predictedProbabilityFieldNested;
|
||||
}
|
||||
|
||||
public List<Tuple<String, String>> listPotentiallyRequiredFields() {
|
||||
return Arrays.asList(
|
||||
Tuple.tuple(ACTUAL_FIELD.getPreferredName(), actualField),
|
||||
Tuple.tuple(PREDICTED_FIELD.getPreferredName(), predictedField),
|
||||
Tuple.tuple(TOP_CLASSES_FIELD.getPreferredName(), topClassesField),
|
||||
Tuple.tuple(PREDICTED_CLASS_FIELD.getPreferredName(), predictedClassField),
|
||||
Tuple.tuple(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
EvaluationFields that = (EvaluationFields) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
&& Objects.equals(that.topClassesField, this.topClassesField)
|
||||
&& Objects.equals(that.predictedClassField, this.predictedClassField)
|
||||
&& Objects.equals(that.predictedProbabilityField, this.predictedProbabilityField)
|
||||
&& Objects.equals(that.predictedProbabilityFieldNested, this.predictedProbabilityFieldNested);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(
|
||||
actualField, predictedField, topClassesField, predictedClassField, predictedProbabilityField, predictedProbabilityFieldNested);
|
||||
}
|
||||
}
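As a rough illustration (not part of this commit) of how the class above is consumed, the sketch below builds an EvaluationFields with a classification-style layout and prints the (descriptor, field) pairs that Evaluation.getRequiredFields() later filters against each metric's getRequiredFields(); the wrapper class name and all field values are hypothetical.

import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;

public class EvaluationFieldsSketch {
    public static void main(String[] args) {
        EvaluationFields fields = new EvaluationFields(
            "animal_class",                     // actual_field
            null,                               // predicted_field (not set)
            "ml.top_classes",                   // top_classes_field
            "ml.top_classes.class_name",        // predicted_class_field
            "ml.top_classes.class_probability", // predicted_probability_field
            true);                              // predicted probability lives under the nested top_classes field

        // Each tuple pairs a field descriptor (e.g. "actual_field") with the configured
        // field name, or null when the field was not provided.
        for (Tuple<String, String> field : fields.listPotentiallyRequiredFields()) {
            System.out.println(field.v1() + " -> " + field.v2());
        }
    }
}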
@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
|||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* {@link EvaluationMetric} class represents a metric to evaluate.
|
||||
|
@ -26,16 +27,18 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
|||
*/
|
||||
String getName();
|
||||
|
||||
/**
|
||||
* Returns the set of fields that this metric requires in order to be calculated.
|
||||
*/
|
||||
Set<String> getRequiredFields();
|
||||
|
||||
/**
|
||||
* Builds the aggregations that collect the data required to compute the metric
|
||||
* @param parameters settings that may be needed by aggregations
|
||||
* @param actualField the field that stores the actual value
|
||||
* @param predictedField the field that stores the predicted value (class name or probability)
|
||||
* @param fields fields that may be needed by aggregations
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField);
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters, EvaluationFields fields);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
|
|
|
@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
|||
/**
|
||||
* Encapsulates parameters needed by evaluation.
|
||||
*/
|
||||
public class EvaluationParameters {
|
||||
public final class EvaluationParameters {
|
||||
|
||||
/**
|
||||
* Maximum number of buckets allowed in any single search request.
|
||||
|
|
|
@ -10,13 +10,13 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ScoreByThresholdResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
|
@ -63,19 +63,28 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
|
||||
// Outlier detection metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRoc.NAME)),
|
||||
AucRoc::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc.NAME)),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, Precision.NAME)),
|
||||
Precision::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision.NAME)),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, Recall.NAME)),
|
||||
Recall::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall.NAME)),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME)),
|
||||
ConfusionMatrix::fromXContent),
|
||||
|
||||
// Classification metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, AucRoc.NAME)),
|
||||
AucRoc::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)),
|
||||
MulticlassConfusionMatrix::fromXContent),
|
||||
|
@ -83,15 +92,11 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)),
|
||||
Accuracy::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent),
|
||||
new ParseField(registeredMetricName(Classification.NAME, Precision.NAME)),
|
||||
Precision::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent),
|
||||
new ParseField(registeredMetricName(Classification.NAME, Recall.NAME)),
|
||||
Recall::fromXContent),
|
||||
|
||||
// Regression metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
|
@ -124,17 +129,23 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
|
||||
// Evaluation metrics
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(OutlierDetection.NAME, AucRoc.NAME),
|
||||
AucRoc::new),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(OutlierDetection.NAME, Precision.NAME),
|
||||
Precision::new),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(OutlierDetection.NAME, Recall.NAME),
|
||||
Recall::new),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME),
|
||||
ConfusionMatrix::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Classification.NAME, AucRoc.NAME),
|
||||
AucRoc::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
|
||||
MulticlassConfusionMatrix::new),
|
||||
|
@ -142,13 +153,11 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
registeredMetricName(Classification.NAME, Accuracy.NAME),
|
||||
Accuracy::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new),
|
||||
registeredMetricName(Classification.NAME, Precision.NAME),
|
||||
Precision::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new),
|
||||
registeredMetricName(Classification.NAME, Recall.NAME),
|
||||
Recall::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
||||
MeanSquaredError::new),
|
||||
|
@ -163,15 +172,15 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
RSquared::new),
|
||||
|
||||
// Evaluation metrics results
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(OutlierDetection.NAME, AucRoc.NAME),
|
||||
AucRoc.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(OutlierDetection.NAME, ScoreByThresholdResult.NAME),
|
||||
ScoreByThresholdResult::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME),
|
||||
ConfusionMatrix.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Classification.NAME, AucRoc.NAME),
|
||||
AucRoc.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
|
||||
MulticlassConfusionMatrix.Result::new),
|
||||
|
@ -179,13 +188,11 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
registeredMetricName(Classification.NAME, Accuracy.NAME),
|
||||
Accuracy.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new),
|
||||
registeredMetricName(Classification.NAME, Precision.NAME),
|
||||
Precision.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new),
|
||||
registeredMetricName(Classification.NAME, Recall.NAME),
|
||||
Recall.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
||||
MeanSquaredError.Result::new),
|
||||
|
|
|
@ -0,0 +1,317 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.Percentiles;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
|
||||
* The ROC curve is a plot of the TPR (true positive rate) against
|
||||
* the FPR (false positive rate) over a varying threshold.
|
||||
*
|
||||
* This particular implementation is making use of ES aggregations
|
||||
* to calculate the curve. It then uses the trapezoidal rule to calculate
|
||||
* the AUC.
|
||||
*
|
||||
* In particular, in order to calculate the ROC, we get percentiles of TP
|
||||
* and FP against the predicted probability. We call those Rate-Threshold
|
||||
* curves. We then scan ROC points from each Rate-Threshold curve against the
|
||||
* other using interpolation. This gives us an approximation of the ROC curve
|
||||
* that has the advantage of being efficient and resilient to some edge cases.
|
||||
*
|
||||
* When this is used for multi-class classification, it will calculate the ROC
|
||||
* curve of each class versus the rest.
|
||||
*/
|
||||
public abstract class AbstractAucRoc implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("auc_roc");
|
||||
|
||||
protected AbstractAucRoc() {}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
protected static double[] percentilesArray(Percentiles percentiles) {
|
||||
double[] result = new double[99];
|
||||
percentiles.forEach(percentile -> {
|
||||
if (Double.isNaN(percentile.getValue())) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at all the percentiles values to be finite numbers", NAME.getPreferredName());
|
||||
}
|
||||
result[((int) percentile.getPercent()) - 1] = percentile.getValue();
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Visible for testing
|
||||
*/
|
||||
protected static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) {
|
||||
assert tpPercentiles.length == fpPercentiles.length;
|
||||
assert tpPercentiles.length == 99;
|
||||
|
||||
List<AucRocPoint> aucRocCurve = new ArrayList<>();
|
||||
aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
|
||||
aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
|
||||
RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
|
||||
RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);
|
||||
aucRocCurve.addAll(tpCurve.scanPoints(fpCurve));
|
||||
aucRocCurve.addAll(fpCurve.scanPoints(tpCurve));
|
||||
Collections.sort(aucRocCurve);
|
||||
return aucRocCurve;
|
||||
}
|
||||
|
||||
/**
|
||||
* Visible for testing
|
||||
*/
|
||||
protected static double calculateAucScore(List<AucRocPoint> rocCurve) {
|
||||
// Calculates AUC based on the trapezoid rule
|
||||
double aucRoc = 0.0;
|
||||
for (int i = 1; i < rocCurve.size(); i++) {
|
||||
AucRocPoint left = rocCurve.get(i - 1);
|
||||
AucRocPoint right = rocCurve.get(i);
|
||||
aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2;
|
||||
}
|
||||
return aucRoc;
|
||||
}
|
||||
|
||||
private static class RateThresholdCurve {
|
||||
|
||||
private final double[] percentiles;
|
||||
private final boolean isTp;
|
||||
|
||||
private RateThresholdCurve(double[] percentiles, boolean isTp) {
|
||||
this.percentiles = percentiles;
|
||||
this.isTp = isTp;
|
||||
}
|
||||
|
||||
private double getRate(int index) {
|
||||
return 1 - 0.01 * (index + 1);
|
||||
}
|
||||
|
||||
private double getThreshold(int index) {
|
||||
return percentiles[index];
|
||||
}
|
||||
|
||||
private double interpolateRate(double threshold) {
|
||||
int binarySearchResult = Arrays.binarySearch(percentiles, threshold);
|
||||
if (binarySearchResult >= 0) {
|
||||
return getRate(binarySearchResult);
|
||||
} else {
|
||||
int right = (binarySearchResult * -1) -1;
|
||||
int left = right - 1;
|
||||
if (right >= percentiles.length) {
|
||||
return 0.0;
|
||||
} else if (left < 0) {
|
||||
return 1.0;
|
||||
} else {
|
||||
double rightRate = getRate(right);
|
||||
double leftRate = getRate(left);
|
||||
return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<AucRocPoint> scanPoints(RateThresholdCurve againstCurve) {
|
||||
List<AucRocPoint> points = new ArrayList<>();
|
||||
for (int index = 0; index < percentiles.length; index++) {
|
||||
double rate = getRate(index);
|
||||
double scannedThreshold = getThreshold(index);
|
||||
double againstRate = againstCurve.interpolateRate(scannedThreshold);
|
||||
AucRocPoint point;
|
||||
if (isTp) {
|
||||
point = new AucRocPoint(rate, againstRate, scannedThreshold);
|
||||
} else {
|
||||
point = new AucRocPoint(againstRate, rate, scannedThreshold);
|
||||
}
|
||||
points.add(point);
|
||||
}
|
||||
return points;
|
||||
}
|
||||
}
|
||||
|
||||
public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXContentObject, Writeable {
|
||||
|
||||
private static final String TPR = "tpr";
|
||||
private static final String FPR = "fpr";
|
||||
private static final String THRESHOLD = "threshold";
|
||||
|
||||
private final double tpr;
|
||||
private final double fpr;
|
||||
private final double threshold;
|
||||
|
||||
AucRocPoint(double tpr, double fpr, double threshold) {
|
||||
this.tpr = tpr;
|
||||
this.fpr = fpr;
|
||||
this.threshold = threshold;
|
||||
}
|
||||
|
||||
private AucRocPoint(StreamInput in) throws IOException {
|
||||
this.tpr = in.readDouble();
|
||||
this.fpr = in.readDouble();
|
||||
this.threshold = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(AucRocPoint o) {
|
||||
return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed()
|
||||
.thenComparing(p -> p.fpr)
|
||||
.thenComparing(p -> p.tpr)
|
||||
.compare(this, o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(tpr);
|
||||
out.writeDouble(fpr);
|
||||
out.writeDouble(threshold);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(TPR, tpr);
|
||||
builder.field(FPR, fpr);
|
||||
builder.field(THRESHOLD, threshold);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AucRocPoint that = (AucRocPoint) o;
|
||||
return tpr == that.tpr
|
||||
&& fpr == that.fpr
|
||||
&& threshold == that.threshold;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(tpr, fpr, threshold);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
||||
private static double interpolate(double x, double x1, double y1, double x2, double y2) {
|
||||
return y1 + (x - x1) * (y2 - y1) / (x2 - x1);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final String SCORE = "score";
|
||||
private static final String DOC_COUNT = "doc_count";
|
||||
private static final String CURVE = "curve";
|
||||
|
||||
private final double score;
|
||||
private final Long docCount;
|
||||
private final List<AucRocPoint> curve;
|
||||
|
||||
public Result(double score, Long docCount, List<AucRocPoint> curve) {
|
||||
this.score = score;
|
||||
this.docCount = docCount;
|
||||
this.curve = Objects.requireNonNull(curve);
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.score = in.readDouble();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.docCount = in.readOptionalLong();
|
||||
} else {
|
||||
this.docCount = null;
|
||||
}
|
||||
this.curve = in.readList(AucRocPoint::new);
|
||||
}
|
||||
|
||||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public Long getDocCount() {
|
||||
return docCount;
|
||||
}
|
||||
|
||||
public List<AucRocPoint> getCurve() {
|
||||
return Collections.unmodifiableList(curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(score);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeOptionalLong(docCount);
|
||||
}
|
||||
out.writeList(curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(SCORE, score);
|
||||
if (docCount != null) {
|
||||
builder.field(DOC_COUNT, docCount);
|
||||
}
|
||||
if (curve.isEmpty() == false) {
|
||||
builder.field(CURVE, curve);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return score == that.score
|
||||
&& Objects.equals(docCount, that.docCount)
|
||||
&& Objects.equals(curve, that.curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(score, docCount, curve);
|
||||
}
|
||||
}
|
||||
}
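Stated as formulas (added here for clarity, not part of the commit): getRate() assigns the rate belonging to the i-th percentile threshold (i = 1..99), and calculateAucScore() applies the trapezoidal rule to the ROC points sorted by decreasing threshold:

r_i = 1 - \frac{i}{100}, \qquad i = 1, \ldots, 99

\mathrm{AUC} \;\approx\; \sum_{k=2}^{n} \frac{(\mathrm{fpr}_k - \mathrm{fpr}_{k-1})\,(\mathrm{tpr}_k + \mathrm{tpr}_{k-1})}{2}

The two sentinel points added in buildAucRocCurve, (tpr, fpr) = (0, 0) at threshold 1.0 and (1, 1) at threshold 0.0, anchor the curve at its endpoints before the interpolated points are merged in and sorted.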
@ -11,6 +11,7 @@ import org.elasticsearch.common.collect.Tuple;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
@ -22,6 +23,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
|
|||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -33,6 +35,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
@ -95,21 +98,24 @@ public class Accuracy implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField.trySet(actualField);
|
||||
this.actualField.trySet(fields.getActualField());
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
|
||||
if (overallAccuracy.get() == null) {
|
||||
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
|
||||
Script script = PainlessScripts.buildIsEqualScript(fields.getActualField(), fields.getPredictedField());
|
||||
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script));
|
||||
}
|
||||
if (result.get() == null) {
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
|
||||
matrix.aggs(parameters, actualField, predictedField);
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(parameters, fields);
|
||||
aggs.addAll(matrixAggs.v1());
|
||||
pipelineAggs.addAll(matrixAggs.v2());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
||||
import org.elasticsearch.search.aggregations.bucket.nested.Nested;
|
||||
import org.elasticsearch.search.aggregations.metrics.Percentiles;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
|
||||
* The ROC curve is a plot of the TPR (true positive rate) against
|
||||
* the FPR (false positive rate) over a varying threshold.
|
||||
*
|
||||
* This particular implementation is making use of ES aggregations
|
||||
* to calculate the curve. It then uses the trapezoidal rule to calculate
|
||||
* the AUC.
|
||||
*
|
||||
* In particular, in order to calculate the ROC, we get percentiles of TP
|
||||
* and FP against the predicted probability. We call those Rate-Threshold
|
||||
* curves. We then scan ROC points from each Rate-Threshold curve against the
|
||||
* other using interpolation. This gives us an approximation of the ROC curve
|
||||
* that has the advantage of being efficient and resilient to some edge cases.
|
||||
*
|
||||
* When this is used for multi-class classification, it will calculate the ROC
|
||||
* curve of each class versus the rest.
|
||||
*/
|
||||
public class AucRoc extends AbstractAucRoc {
|
||||
|
||||
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
|
||||
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
|
||||
public static final ConstructingObjectParser<AucRoc, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new AucRoc((Boolean) a[0], (String) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), CLASS_NAME);
|
||||
}
|
||||
|
||||
private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
|
||||
private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
|
||||
private static final String NESTED_AGG_NAME = "nested";
|
||||
private static final String NESTED_FILTER_AGG_NAME = "nested_filter";
|
||||
private static final String PERCENTILES_AGG_NAME = "percentiles";
|
||||
|
||||
public static AucRoc fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final boolean includeCurve;
|
||||
private final String className;
|
||||
private final SetOnce<EvaluationFields> fields = new SetOnce<>();
|
||||
private final SetOnce<EvaluationMetricResult> result = new SetOnce<>();
|
||||
|
||||
public AucRoc(Boolean includeCurve, String className) {
|
||||
this.includeCurve = includeCurve == null ? false : includeCurve;
|
||||
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME.getPreferredName());
|
||||
}
|
||||
|
||||
public AucRoc(StreamInput in) throws IOException {
|
||||
this.includeCurve = in.readBoolean();
|
||||
this.className = in.readOptionalString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeBoolean(includeCurve);
|
||||
out.writeOptionalString(className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
|
||||
if (className != null) {
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(
|
||||
EvaluationFields.ACTUAL_FIELD.getPreferredName(),
|
||||
EvaluationFields.PREDICTED_CLASS_FIELD.getPreferredName(),
|
||||
EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AucRoc that = (AucRoc) o;
|
||||
return includeCurve == that.includeCurve
|
||||
&& Objects.equals(className, that.className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(includeCurve, className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
EvaluationFields fields) {
|
||||
if (result.get() != null) {
|
||||
return Tuple.tuple(Arrays.asList(), Arrays.asList());
|
||||
}
|
||||
// Store given {@code fields} for the purpose of generating error messages in {@code process}.
|
||||
this.fields.trySet(fields);
|
||||
|
||||
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
|
||||
AggregationBuilder percentilesAgg =
|
||||
AggregationBuilders
|
||||
.percentiles(PERCENTILES_AGG_NAME)
|
||||
.field(fields.getPredictedProbabilityField())
|
||||
.percentiles(percentiles);
|
||||
AggregationBuilder nestedAgg =
|
||||
AggregationBuilders
|
||||
.nested(NESTED_AGG_NAME, fields.getTopClassesField())
|
||||
.subAggregation(
|
||||
AggregationBuilders
|
||||
.filter(NESTED_FILTER_AGG_NAME, QueryBuilders.termQuery(fields.getPredictedClassField(), className))
|
||||
.subAggregation(percentilesAgg));
|
||||
QueryBuilder actualIsTrueQuery = QueryBuilders.termQuery(fields.getActualField(), className);
|
||||
AggregationBuilder percentilesForClassValueAgg =
|
||||
AggregationBuilders
|
||||
.filter(TRUE_AGG_NAME, actualIsTrueQuery)
|
||||
.subAggregation(nestedAgg);
|
||||
AggregationBuilder percentilesForRestAgg =
|
||||
AggregationBuilders
|
||||
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery))
|
||||
.subAggregation(nestedAgg);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg),
|
||||
Arrays.asList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (result.get() != null) {
|
||||
return;
|
||||
}
|
||||
Filter classAgg = aggs.get(TRUE_AGG_NAME);
|
||||
Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME);
|
||||
Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
|
||||
if (classAgg.getDocCount() == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have the value [{}]",
|
||||
getName(), fields.get().getActualField(), className);
|
||||
}
|
||||
if (classNestedFilter.getDocCount() == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have the value [{}]",
|
||||
getName(), fields.get().getPredictedClassField(), className);
|
||||
}
|
||||
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
|
||||
double[] tpPercentiles = percentilesArray(classPercentiles);
|
||||
|
||||
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
|
||||
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
|
||||
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
|
||||
if (restAgg.getDocCount() == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have a different value than [{}]",
|
||||
getName(), fields.get().getActualField(), className);
|
||||
}
|
||||
if (restNestedFilter.getDocCount() == 0) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] requires at least one [{}] to have the value [{}]",
|
||||
getName(), fields.get().getPredictedClassField(), className);
|
||||
}
|
||||
Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
|
||||
double[] fpPercentiles = percentilesArray(restPercentiles);
|
||||
|
||||
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
|
||||
double aucRocScore = calculateAucScore(aucRocCurve);
|
||||
result.set(
|
||||
new Result(
|
||||
aucRocScore,
|
||||
classNestedFilter.getDocCount() + restNestedFilter.getDocCount(),
|
||||
includeCurve ? aucRocCurve : Collections.emptyList()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result.get());
|
||||
}
|
||||
}
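A small sketch (not part of this commit, with a made-up class value and wrapper class name) of what the metric above declares as its inputs. The descriptors returned by getRequiredFields() are the ones Evaluation.buildSearch() uses to add exists filters, nested under the top classes field for the class-name and probability properties.

import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;

public class AucRocRequiredFieldsSketch {
    public static void main(String[] args) {
        AucRoc aucRoc = new AucRoc(false, "dog"); // include_curve = false, class_name = "dog"
        // Prints the descriptors actual_field, predicted_class_field and
        // predicted_probability_field (a Set, so the order is not guaranteed).
        System.out.println(aucRoc.getRequiredFields());
    }
}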
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -13,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
|
@ -21,6 +23,9 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.TOP_CLASSES_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
|
@ -30,17 +35,22 @@ public class Classification implements Evaluation {
|
|||
|
||||
public static final ParseField NAME = new ParseField("classification");
|
||||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
private static final String DEFAULT_TOP_CLASSES_FIELD = "ml.top_classes";
|
||||
private static final String DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX = ".class_name";
|
||||
private static final String DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX = ".class_probability";
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
public static final ConstructingObjectParser<Classification, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(),
|
||||
a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List<EvaluationMetric>) a[3]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TOP_CLASSES_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
|
||||
}
|
||||
|
@@ -50,25 +60,35 @@ public class Classification implements Evaluation {
    }

    /**
     * The field containing the actual value
     * The value of this field is assumed to be categorical
     * The collection of fields in the index being evaluated.
     * fields.getActualField() is assumed to be a ground truth label.
     * fields.getPredictedField() is assumed to be a predicted label.
     * fields.getPredictedClassField() and fields.getPredictedProbabilityField() are assumed to be properties under the same nested field.
     */
    private final String actualField;

    /**
     * The field containing the predicted value
     * The value of this field is assumed to be categorical
     */
    private final String predictedField;
    private final EvaluationFields fields;

    /**
     * The list of metrics to calculate
     */
    private final List<EvaluationMetric> metrics;

    public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
        this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
        this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
    public Classification(String actualField,
                          @Nullable String predictedField,
                          @Nullable String topClassesField,
                          @Nullable List<EvaluationMetric> metrics) {
        if (topClassesField == null) {
            topClassesField = DEFAULT_TOP_CLASSES_FIELD;
        }
        String predictedClassField = topClassesField + DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX;
        String predictedProbabilityField = topClassesField + DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX;
        this.fields =
            new EvaluationFields(
                ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD),
                predictedField,
                topClassesField,
                predictedClassField,
                predictedProbabilityField,
                true);
        this.metrics = initMetrics(metrics, Classification::defaultMetrics);
    }

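Reviewer note, not part of the commit: when the optional top-classes field is omitted, the constructor above falls back to the defaults declared earlier in this class. The following is a minimal, self-contained sketch of that derivation; the constant values mirror the ones in Classification, while the wrapper class and main method are hypothetical.

    // Standalone illustration of the default field derivation used by the constructor above.
    public class TopClassesFieldDefaultsDemo {

        private static final String DEFAULT_TOP_CLASSES_FIELD = "ml.top_classes";
        private static final String DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX = ".class_name";
        private static final String DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX = ".class_probability";

        public static void main(String[] args) {
            String topClassesField = null;  // caller did not set the optional top-classes field
            if (topClassesField == null) {
                topClassesField = DEFAULT_TOP_CLASSES_FIELD;
            }
            // Prints "ml.top_classes.class_name" and "ml.top_classes.class_probability"
            System.out.println(topClassesField + DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX);
            System.out.println(topClassesField + DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX);
        }
    }
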
@ -77,8 +97,18 @@ public class Classification implements Evaluation {
|
|||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedField = in.readString();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
this.fields =
|
||||
new EvaluationFields(
|
||||
in.readString(),
|
||||
in.readOptionalString(),
|
||||
in.readOptionalString(),
|
||||
in.readOptionalString(),
|
||||
in.readOptionalString(),
|
||||
true);
|
||||
} else {
|
||||
this.fields = new EvaluationFields(in.readString(), in.readString(), null, null, null, true);
|
||||
}
|
||||
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
|
||||
}
|
||||
|
||||
|
@ -88,13 +118,8 @@ public class Classification implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getActualField() {
|
||||
return actualField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getPredictedField() {
|
||||
return predictedField;
|
||||
public EvaluationFields getFields() {
|
||||
return fields;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -109,17 +134,28 @@ public class Classification implements Evaluation {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualField);
|
||||
out.writeString(predictedField);
|
||||
out.writeString(fields.getActualField());
|
||||
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||
out.writeOptionalString(fields.getPredictedField());
|
||||
out.writeOptionalString(fields.getTopClassesField());
|
||||
out.writeOptionalString(fields.getPredictedClassField());
|
||||
out.writeOptionalString(fields.getPredictedProbabilityField());
|
||||
} else {
|
||||
out.writeString(fields.getPredictedField());
|
||||
}
|
||||
out.writeNamedWriteableList(metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField());
|
||||
if (fields.getPredictedField() != null) {
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), fields.getPredictedField());
|
||||
}
|
||||
if (fields.getTopClassesField() != null) {
|
||||
builder.field(TOP_CLASSES_FIELD.getPreferredName(), fields.getTopClassesField());
|
||||
}
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
|
@ -135,13 +171,12 @@ public class Classification implements Evaluation {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Classification that = (Classification) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
return Objects.equals(that.fields, this.fields)
|
||||
&& Objects.equals(that.metrics, this.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedField, metrics);
|
||||
return Objects.hash(fields, metrics);
|
||||
}
|
||||
}
@ -13,6 +13,7 @@ import org.elasticsearch.common.collect.Tuple;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
|
@ -27,6 +28,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
|||
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -39,6 +41,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static java.util.Comparator.comparing;
|
||||
|
@ -125,10 +128,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
String actualField = fields.getActualField();
|
||||
String predictedField = fields.getPredictedField();
|
||||
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
@ -11,6 +11,7 @@ import org.elasticsearch.common.collect.Tuple;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
@ -28,6 +29,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
|||
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -40,6 +42,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
@ -90,10 +93,16 @@ public class Precision implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
String actualField = fields.getActualField();
|
||||
String predictedField = fields.getPredictedField();
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField.trySet(actualField);
|
||||
if (topActualClassNames.get() == null) { // This is step 1
@ -11,6 +11,7 @@ import org.elasticsearch.common.collect.Tuple;
|
|||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
|
@ -25,6 +26,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -37,6 +39,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
@ -84,10 +87,16 @@ public class Recall implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
String actualField = fields.getActualField();
|
||||
String predictedField = fields.getPredictedField();
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField.trySet(actualField);
|
||||
if (result.get() != null) {
@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
|
|||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
|
@ -17,6 +18,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -26,6 +28,7 @@ import java.io.IOException;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection.actualIsTrueQuery;
|
||||
|
||||
|
@ -66,13 +69,20 @@ abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
|
|||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(
|
||||
EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedProbabilityField) {
|
||||
EvaluationFields fields) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
String actualField = fields.getActualField();
|
||||
String predictedProbabilityField = fields.getPredictedProbabilityField();
|
||||
return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), Collections.emptyList());
|
||||
}
@ -5,14 +5,13 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
|
@ -21,20 +20,19 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
|
|||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
||||
import org.elasticsearch.search.aggregations.metrics.Percentiles;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
@ -58,30 +56,28 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetect
|
|||
* When this is used for multi-class classification, it will calculate the ROC
|
||||
* curve of each class versus the rest.
|
||||
*/
|
||||
public class AucRoc implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("auc_roc");
|
||||
public class AucRoc extends AbstractAucRoc {
|
||||
|
||||
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
|
||||
|
||||
public static final ConstructingObjectParser<AucRoc, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
|
||||
a -> new AucRoc((Boolean) a[0]));
|
||||
public static final ConstructingObjectParser<AucRoc, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new AucRoc((Boolean) a[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
|
||||
}
|
||||
|
||||
private static final String PERCENTILES = "percentiles";
|
||||
|
||||
private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
|
||||
private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
|
||||
private static final String PERCENTILES_AGG_NAME = "percentiles";
|
||||
|
||||
public static AucRoc fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final boolean includeCurve;
|
||||
private EvaluationMetricResult result;
|
||||
private final SetOnce<EvaluationFields> fields = new SetOnce<>();
|
||||
private final SetOnce<EvaluationMetricResult> result = new SetOnce<>();
|
||||
|
||||
public AucRoc(Boolean includeCurve) {
|
||||
this.includeCurve = includeCurve == null ? false : includeCurve;
|
||||
|
@ -110,8 +106,9 @@ public class AucRoc implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(
|
||||
EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -119,7 +116,7 @@ public class AucRoc implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AucRoc that = (AucRoc) o;
|
||||
return Objects.equals(includeCurve, that.includeCurve);
|
||||
return includeCurve == that.includeCurve;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -129,22 +126,29 @@ public class AucRoc implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedProbabilityField) {
|
||||
if (result != null) {
|
||||
EvaluationFields fields) {
|
||||
if (result.get() != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
// Store given {@code fields} for the purpose of generating error messages in {@code process}.
|
||||
this.fields.trySet(fields);
|
||||
|
||||
String actualField = fields.getActualField();
|
||||
String predictedProbabilityField = fields.getPredictedProbabilityField();
|
||||
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
|
||||
AggregationBuilder percentilesAgg =
|
||||
AggregationBuilders
|
||||
.percentiles(PERCENTILES_AGG_NAME)
|
||||
.field(predictedProbabilityField)
|
||||
.percentiles(percentiles);
|
||||
AggregationBuilder percentilesForClassValueAgg =
|
||||
AggregationBuilders
|
||||
.filter(TRUE_AGG_NAME, actualIsTrueQuery(actualField))
|
||||
.subAggregation(
|
||||
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
|
||||
.subAggregation(percentilesAgg);
|
||||
AggregationBuilder percentilesForRestAgg =
|
||||
AggregationBuilders
|
||||
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
|
||||
.subAggregation(
|
||||
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
|
||||
.subAggregation(percentilesAgg);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg),
|
||||
Collections.emptyList());
|
||||
|
@@ -152,216 +156,33 @@ public class AucRoc implements EvaluationMetric {

    @Override
    public void process(Aggregations aggs) {
        if (result.get() != null) {
            return;
        }
        Filter classAgg = aggs.get(TRUE_AGG_NAME);
        if (classAgg.getDocCount() == 0) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires at least one [{}] to have the value [{}]", getName(), fields.get().getActualField(), "true");
        }
        double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES_AGG_NAME));
        Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
        double[] tpPercentiles =
            percentilesArray(
                classAgg.getAggregations().get(PERCENTILES),
                "[" + getName() + "] requires at least one actual_field to have the value [true]");
        double[] fpPercentiles =
            percentilesArray(
                restAgg.getAggregations().get(PERCENTILES),
                "[" + getName() + "] requires at least one actual_field to have a different value than [true]");
        if (restAgg.getDocCount() == 0) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires at least one [{}] to have a different value than [{}]", getName(), fields.get().getActualField(), "true");
        }
        double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES_AGG_NAME));

        List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = calculateAucScore(aucRocCurve);
        result = new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
        result.set(
            new Result(
                aucRocScore,
                classAgg.getDocCount() + restAgg.getDocCount(),
                includeCurve ? aucRocCurve : Collections.emptyList()));
    }

    @Override
    public Optional<EvaluationMetricResult> getResult() {
        return Optional.ofNullable(result);
    }

    private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) {
        double[] result = new double[99];
        percentiles.forEach(percentile -> {
            if (Double.isNaN(percentile.getValue())) {
                throw ExceptionsHelper.badRequestException(errorIfUndefined);
            }
            result[((int) percentile.getPercent()) - 1] = percentile.getValue();
        });
        return result;
    }

    /**
     * Visible for testing
     */
    static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) {
        assert tpPercentiles.length == fpPercentiles.length;
        assert tpPercentiles.length == 99;

        List<AucRocPoint> aucRocCurve = new ArrayList<>();
        aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
        aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
        RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
        RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);
        aucRocCurve.addAll(tpCurve.scanPoints(fpCurve));
        aucRocCurve.addAll(fpCurve.scanPoints(tpCurve));
        Collections.sort(aucRocCurve);
        return aucRocCurve;
    }

    /**
     * Visible for testing
     */
    static double calculateAucScore(List<AucRocPoint> rocCurve) {
        // Calculates AUC based on the trapezoid rule
        double aucRoc = 0.0;
        for (int i = 1; i < rocCurve.size(); i++) {
            AucRocPoint left = rocCurve.get(i - 1);
            AucRocPoint right = rocCurve.get(i);
            aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2;
        }
        return aucRoc;
    }

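Reviewer note, not part of the commit: the loop in calculateAucScore above is the standard trapezoid rule over consecutive ROC points, and the interpolate helper that appears later in this class is plain linear interpolation. In LaTeX notation (mine, not the source's), with curve points indexed 0..n-1 in sorted order:

    \mathrm{AUC} \approx \sum_{i=1}^{n-1} \frac{(\mathrm{fpr}_i - \mathrm{fpr}_{i-1})\,(\mathrm{tpr}_i + \mathrm{tpr}_{i-1})}{2},
    \qquad
    y = y_1 + (x - x_1)\,\frac{y_2 - y_1}{x_2 - x_1}
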
private static class RateThresholdCurve {
|
||||
|
||||
private final double[] percentiles;
|
||||
private final boolean isTp;
|
||||
|
||||
private RateThresholdCurve(double[] percentiles, boolean isTp) {
|
||||
this.percentiles = percentiles;
|
||||
this.isTp = isTp;
|
||||
}
|
||||
|
||||
private double getRate(int index) {
|
||||
return 1 - 0.01 * (index + 1);
|
||||
}
|
||||
|
||||
private double getThreshold(int index) {
|
||||
return percentiles[index];
|
||||
}
|
||||
|
||||
private double interpolateRate(double threshold) {
|
||||
int binarySearchResult = Arrays.binarySearch(percentiles, threshold);
|
||||
if (binarySearchResult >= 0) {
|
||||
return getRate(binarySearchResult);
|
||||
} else {
|
||||
int right = (binarySearchResult * -1) -1;
|
||||
int left = right - 1;
|
||||
if (right >= percentiles.length) {
|
||||
return 0.0;
|
||||
} else if (left < 0) {
|
||||
return 1.0;
|
||||
} else {
|
||||
double rightRate = getRate(right);
|
||||
double leftRate = getRate(left);
|
||||
return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<AucRocPoint> scanPoints(RateThresholdCurve againstCurve) {
|
||||
List<AucRocPoint> points = new ArrayList<>();
|
||||
for (int index = 0; index < percentiles.length; index++) {
|
||||
double rate = getRate(index);
|
||||
double scannedThreshold = getThreshold(index);
|
||||
double againstRate = againstCurve.interpolateRate(scannedThreshold);
|
||||
AucRocPoint point;
|
||||
if (isTp) {
|
||||
point = new AucRocPoint(rate, againstRate, scannedThreshold);
|
||||
} else {
|
||||
point = new AucRocPoint(againstRate, rate, scannedThreshold);
|
||||
}
|
||||
points.add(point);
|
||||
}
|
||||
return points;
|
||||
}
|
||||
}
|
||||
|
||||
public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXContentObject, Writeable {
|
||||
double tpr;
|
||||
double fpr;
|
||||
double threshold;
|
||||
|
||||
private AucRocPoint(double tpr, double fpr, double threshold) {
|
||||
this.tpr = tpr;
|
||||
this.fpr = fpr;
|
||||
this.threshold = threshold;
|
||||
}
|
||||
|
||||
private AucRocPoint(StreamInput in) throws IOException {
|
||||
this.tpr = in.readDouble();
|
||||
this.fpr = in.readDouble();
|
||||
this.threshold = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(AucRocPoint o) {
|
||||
return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed()
|
||||
.thenComparing(p -> p.fpr)
|
||||
.thenComparing(p -> p.tpr)
|
||||
.compare(this, o);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(tpr);
|
||||
out.writeDouble(fpr);
|
||||
out.writeDouble(threshold);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field("tpr", tpr);
|
||||
builder.field("fpr", fpr);
|
||||
builder.field("threshold", threshold);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
||||
private static double interpolate(double x, double x1, double y1, double x2, double y2) {
|
||||
return y1 + (x - x1) * (y2 - y1) / (x2 - x1);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private final double score;
|
||||
private final List<AucRocPoint> curve;
|
||||
|
||||
public Result(double score, List<AucRocPoint> curve) {
|
||||
this.score = score;
|
||||
this.curve = Objects.requireNonNull(curve);
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.score = in.readDouble();
|
||||
this.curve = in.readList(AucRocPoint::new);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(OutlierDetection.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(score);
|
||||
out.writeList(curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field("score", score);
|
||||
if (curve.isEmpty() == false) {
|
||||
builder.field("curve", curve);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
return Optional.ofNullable(result.get());
|
||||
}
|
||||
}
@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
|||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
|
@ -23,6 +24,8 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_PROBABILITY_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
|
@ -32,8 +35,6 @@ public class OutlierDetection implements Evaluation {
|
|||
|
||||
public static final ParseField NAME = new ParseField("outlier_detection", "binary_soft_classification");
|
||||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
public static final ConstructingObjectParser<OutlierDetection, Void> PARSER = new ConstructingObjectParser<>(
|
||||
|
@ -50,30 +51,34 @@ public class OutlierDetection implements Evaluation {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
static QueryBuilder actualIsTrueQuery(String actualField) {
|
||||
public static QueryBuilder actualIsTrueQuery(String actualField) {
|
||||
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
|
||||
}
|
||||
|
||||
/**
|
||||
* The field where the actual class is marked up.
|
||||
* The value of this field is assumed to either be 1 or 0, or true or false.
|
||||
* The collection of fields in the index being evaluated.
|
||||
* fields.getActualField() is assumed to either be 1 or 0, or true or false.
|
||||
* fields.getPredictedProbabilityField() is assumed to be a number in [0.0, 1.0].
|
||||
* Other fields are not needed by this evaluation.
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field of the predicted probability in [0.0, 1.0].
|
||||
*/
|
||||
private final String predictedProbabilityField;
|
||||
private final EvaluationFields fields;
|
||||
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public OutlierDetection(String actualField, String predictedProbabilityField,
|
||||
public OutlierDetection(String actualField,
|
||||
String predictedProbabilityField,
|
||||
@Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
|
||||
this.fields =
|
||||
new EvaluationFields(
|
||||
ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD),
|
||||
false);
|
||||
this.metrics = initMetrics(metrics, OutlierDetection::defaultMetrics);
|
||||
}
|
||||
|
||||
|
@ -86,8 +91,7 @@ public class OutlierDetection implements Evaluation {
|
|||
}
|
||||
|
||||
public OutlierDetection(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedProbabilityField = in.readString();
|
||||
this.fields = new EvaluationFields(in.readString(), null, null, null, in.readString(), false);
|
||||
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
|
||||
}
|
||||
|
||||
|
@ -97,13 +101,8 @@ public class OutlierDetection implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getActualField() {
|
||||
return actualField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getPredictedField() {
|
||||
return predictedProbabilityField;
|
||||
public EvaluationFields getFields() {
|
||||
return fields;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -118,16 +117,16 @@ public class OutlierDetection implements Evaluation {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualField);
|
||||
out.writeString(predictedProbabilityField);
|
||||
out.writeString(fields.getActualField());
|
||||
out.writeString(fields.getPredictedProbabilityField());
|
||||
out.writeNamedWriteableList(metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField());
|
||||
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), fields.getPredictedProbabilityField());
|
||||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
|
@ -144,13 +143,12 @@ public class OutlierDetection implements Evaluation {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
OutlierDetection that = (OutlierDetection) o;
|
||||
return Objects.equals(actualField, that.actualField)
|
||||
&& Objects.equals(predictedProbabilityField, that.predictedProbabilityField)
|
||||
return Objects.equals(fields, that.fields)
|
||||
&& Objects.equals(metrics, that.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedProbabilityField, metrics);
|
||||
return Objects.hash(fields, metrics);
|
||||
}
|
||||
}
@ -10,6 +10,7 @@ import org.elasticsearch.common.ParseField;
|
|||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
|
|||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -31,6 +33,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
@ -86,13 +89,19 @@ public class Huber implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
String actualField = fields.getActualField();
|
||||
String predictedField = fields.getPredictedField();
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, delta * delta)))),
|
||||
Collections.emptyList());
@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
|
|||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -19,6 +20,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
|
|||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -31,6 +33,7 @@ import java.util.List;
|
|||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
|
@ -70,13 +73,19 @@ public class MeanSquaredError implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
String actualField = fields.getActualField();
|
||||
String predictedField = fields.getPredictedField();
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))),
|
||||
Collections.emptyList());
@ -10,6 +10,7 @@ import org.elasticsearch.common.ParseField;
|
|||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
|
|||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -31,6 +33,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
@ -85,13 +88,19 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
String actualField = fields.getActualField();
|
||||
String predictedField = fields.getPredictedField();
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, offset)))),
|
||||
Collections.emptyList());
@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
|
|||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
|
||||
import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
@ -32,6 +34,7 @@ import java.util.List;
|
|||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
|
@ -74,13 +77,19 @@ public class RSquared implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<String> getRequiredFields() {
|
||||
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
EvaluationFields fields) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
String actualField = fields.getActualField();
|
||||
String predictedField = fields.getPredictedField();
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
|
@ -21,6 +22,8 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_FIELD;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
|
@ -30,8 +33,6 @@ public class Regression implements Evaluation {
|
|||
|
||||
public static final ParseField NAME = new ParseField("regression");
|
||||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -50,16 +51,12 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
* The value of this field is assumed to be numeric
|
||||
* The collection of fields in the index being evaluated.
|
||||
* fields.getActualField() is assumed to be numeric.
|
||||
* fields.getPredictedField() is assumed to be numeric.
|
||||
* Other fields are not needed by this evaluation.
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String predictedField;
|
||||
private final EvaluationFields fields;
|
||||
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
|
@ -67,8 +64,14 @@ public class Regression implements Evaluation {
|
|||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
|
||||
this.fields =
|
||||
new EvaluationFields(
|
||||
ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD),
|
||||
ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
false);
|
||||
this.metrics = initMetrics(metrics, Regression::defaultMetrics);
|
||||
}
|
||||
|
||||
|
@ -77,8 +80,7 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedField = in.readString();
|
||||
this.fields = new EvaluationFields(in.readString(), in.readString(), null, null, null, false);
|
||||
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
|
||||
}
|
||||
|
||||
|
@ -88,13 +90,8 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getActualField() {
|
||||
return actualField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getPredictedField() {
|
||||
return predictedField;
|
||||
public EvaluationFields getFields() {
|
||||
return fields;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -109,16 +106,16 @@ public class Regression implements Evaluation {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualField);
|
||||
out.writeString(predictedField);
|
||||
out.writeString(fields.getActualField());
|
||||
out.writeString(fields.getPredictedField());
|
||||
out.writeNamedWriteableList(metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField());
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), fields.getPredictedField());
|
||||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
|
@ -135,13 +132,12 @@ public class Regression implements Evaluation {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Regression that = (Regression) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
return Objects.equals(that.fields, this.fields)
|
||||
&& Objects.equals(that.metrics, this.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedField, metrics);
|
||||
return Objects.hash(fields, metrics);
|
||||
}
|
||||
}
@ -11,13 +11,14 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRocResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AccuracyResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
@ -25,6 +26,10 @@ import java.util.List;
|
|||
|
||||
public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
||||
|
||||
private static final String OUTLIER_DETECTION = "outlier_detection";
|
||||
private static final String CLASSIFICATION = "classification";
|
||||
private static final String REGRESSION = "regression";
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
|
@ -32,18 +37,35 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
|
|||
|
||||
@Override
|
||||
protected Response createTestInstance() {
|
||||
String evaluationName = randomAlphaOfLength(10);
|
||||
List<EvaluationMetricResult> metrics =
|
||||
Arrays.asList(
|
||||
AccuracyResultTests.createRandom(),
|
||||
PrecisionResultTests.createRandom(),
|
||||
RecallResultTests.createRandom(),
|
||||
MulticlassConfusionMatrixResultTests.createRandom(),
|
||||
new MeanSquaredError.Result(randomDouble()),
|
||||
new MeanSquaredLogarithmicError.Result(randomDouble()),
|
||||
new Huber.Result(randomDouble()),
|
||||
new RSquared.Result(randomDouble()));
|
||||
return new Response(evaluationName, randomSubsetOf(metrics));
|
||||
String evaluationName = randomFrom(OUTLIER_DETECTION, CLASSIFICATION, REGRESSION);
|
||||
List<EvaluationMetricResult> metrics;
|
||||
switch (evaluationName) {
|
||||
case OUTLIER_DETECTION:
|
||||
metrics = randomSubsetOf(
|
||||
Arrays.asList(
|
||||
AucRocResultTests.createRandom()));
|
||||
break;
|
||||
case CLASSIFICATION:
|
||||
metrics = randomSubsetOf(
|
||||
Arrays.asList(
|
||||
AucRocResultTests.createRandom(),
|
||||
AccuracyResultTests.createRandom(),
|
||||
PrecisionResultTests.createRandom(),
|
||||
RecallResultTests.createRandom(),
|
||||
MulticlassConfusionMatrixResultTests.createRandom()));
|
||||
break;
|
||||
case REGRESSION:
|
||||
metrics = randomSubsetOf(
|
||||
Arrays.asList(
|
||||
new MeanSquaredError.Result(randomDouble()),
|
||||
new MeanSquaredLogarithmicError.Result(randomDouble()),
|
||||
new Huber.Result(randomDouble()),
|
||||
new RSquared.Result(randomDouble())));
|
||||
break;
|
||||
default:
|
||||
throw new AssertionError("Please add missing \"case\" variant to the \"switch\" statement");
|
||||
}
|
||||
return new Response(evaluationName, metrics);
|
||||
}
|
||||
|
||||
@Override
@@ -44,7 +44,6 @@ import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
@@ -366,15 +365,27 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
        assertThat(
            new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
            equalTo(Collections.singletonMap("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING)));
        Map<String, Object> expectedTopClassesMapping = new HashMap<String, Object>() {{
            put("type", "nested");
            put("properties", new HashMap<String, Object>() {{
                put("class_name", Collections.singletonMap("bar", "baz"));
                put("class_probability", Collections.singletonMap("type", "double"));
            }});
        }};
        Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
            Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
            "results");
        assertThat(explicitlyMappedFields,
            allOf(
                hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
                hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
        assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")));
        assertThat(explicitlyMappedFields, hasEntry("results.top_classes", expectedTopClassesMapping));
        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING));

        expectedTopClassesMapping = new HashMap<String, Object>() {{
            put("type", "nested");
            put("properties", new HashMap<String, Object>() {{
                put("class_name", Collections.singletonMap("type", "long"));
                put("class_probability", Collections.singletonMap("type", "double"));
            }});
        }};
        explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
            new HashMap<String, Object>() {{
                put("foo", new HashMap<String, String>() {{
@@ -384,10 +395,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                put("bar", Collections.singletonMap("type", "long"));
            }},
            "results");
        assertThat(explicitlyMappedFields,
            allOf(
                hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
                hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
        assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")));
        assertThat(explicitlyMappedFields, hasEntry("results.top_classes", expectedTopClassesMapping));
        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING));

        assertThat(

@@ -0,0 +1,48 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;

import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.test.ESTestCase;

import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class EvaluationFieldsTests extends ESTestCase {

    public void testConstructorAndGetters() {
        EvaluationFields fields = new EvaluationFields("a", "b", "c", "d", "e", true);
        assertThat(fields.getActualField(), is(equalTo("a")));
        assertThat(fields.getPredictedField(), is(equalTo("b")));
        assertThat(fields.getTopClassesField(), is(equalTo("c")));
        assertThat(fields.getPredictedClassField(), is(equalTo("d")));
        assertThat(fields.getPredictedProbabilityField(), is(equalTo("e")));
        assertThat(fields.isPredictedProbabilityFieldNested(), is(true));
    }

    public void testConstructorAndGetters_WithNullValues() {
        EvaluationFields fields = new EvaluationFields("a", null, "c", null, "e", true);
        assertThat(fields.getActualField(), is(equalTo("a")));
        assertThat(fields.getPredictedField(), is(nullValue()));
        assertThat(fields.getTopClassesField(), is(equalTo("c")));
        assertThat(fields.getPredictedClassField(), is(nullValue()));
        assertThat(fields.getPredictedProbabilityField(), is(equalTo("e")));
        assertThat(fields.isPredictedProbabilityFieldNested(), is(true));
    }

    public void testListPotentiallyRequiredFields() {
        EvaluationFields fields = new EvaluationFields("a", "b", "c", "d", "e", randomBoolean());
        assertThat(fields.listPotentiallyRequiredFields().stream().map(Tuple::v2).collect(toList()), contains("a", "b", "c", "d", "e"));
    }

    public void testListPotentiallyRequiredFields_WithNullValues() {
        EvaluationFields fields = new EvaluationFields("a", null, "c", null, "e", randomBoolean());
        assertThat(fields.listPotentiallyRequiredFields().stream().map(Tuple::v2).collect(toList()), contains("a", null, "c", null, "e"));
    }
}
@@ -0,0 +1,105 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.test.ESTestCase;

import java.util.Arrays;
import java.util.List;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class AbstractAucRocTests extends ESTestCase {

    public void testCalculateAucScore_GivenZeroPercentiles() {
        double[] tpPercentiles = zeroPercentiles();
        double[] fpPercentiles = zeroPercentiles();

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        assertThat(aucRocScore, closeTo(0.5, 0.01));
    }

    public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() {
        double[] tpPercentiles = randomPercentiles();
        double[] fpPercentiles = zeroPercentiles();

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        assertThat(aucRocScore, closeTo(1.0, 0.1));
    }

    public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() {
        double[] tpPercentiles = zeroPercentiles();
        double[] fpPercentiles = randomPercentiles();

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        assertThat(aucRocScore, closeTo(0.0, 0.1));
    }

    public void testCalculateAucScore_GivenRandomPercentiles() {
        for (int i = 0; i < 20; i++) {
            double[] tpPercentiles = randomPercentiles();
            double[] fpPercentiles = randomPercentiles();

            List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
            double aucRocScore = AucRoc.calculateAucScore(curve);

            List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
            double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);

            assertThat(aucRocScore, greaterThanOrEqualTo(0.0));
            assertThat(aucRocScore, lessThanOrEqualTo(1.0));
            assertThat(inverseAucRocScore, greaterThanOrEqualTo(0.0));
            assertThat(inverseAucRocScore, lessThanOrEqualTo(1.0));
            assertThat(aucRocScore + inverseAucRocScore, closeTo(1.0, 0.05));
        }
    }

    public void testCalculateAucScore_GivenPrecalculated() {
        double[] tpPercentiles = new double[99];
        double[] fpPercentiles = new double[99];

        double[] tpSimplified = new double[] { 0.3, 0.6, 0.5, 0.8 };
        double[] fpSimplified = new double[] { 0.1, 0.3, 0.5, 0.5 };

        for (int i = 0; i < tpPercentiles.length; i++) {
            int simplifiedIndex = i / 25;
            tpPercentiles[i] = tpSimplified[simplifiedIndex];
            fpPercentiles[i] = fpSimplified[simplifiedIndex];
        }

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
        double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);

        assertThat(aucRocScore, closeTo(0.8, 0.05));
        assertThat(inverseAucRocScore, closeTo(0.2, 0.05));
    }

    public static double[] zeroPercentiles() {
        double[] percentiles = new double[99];
        Arrays.fill(percentiles, 0.0);
        return percentiles;
    }

    public static double[] randomPercentiles() {
        double[] percentiles = new double[99];
        for (int i = 0; i < percentiles.length; i++) {
            percentiles[i] = randomDouble();
        }
        Arrays.sort(percentiles);
        return percentiles;
    }
}
@@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
@@ -32,6 +33,7 @@ import static org.hamcrest.Matchers.equalTo;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {

    private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
    private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);

    @Override
    protected Accuracy doParseInstance(XContentParser parser) throws IOException {
@@ -88,7 +90,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
        Accuracy accuracy = new Accuracy();
        accuracy.process(aggs);

        assertThat(accuracy.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
        assertThat(accuracy.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));

        Result result = accuracy.getResult().get();
        assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
@@ -130,7 +132,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
            mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));

        Accuracy accuracy = new Accuracy();
        accuracy.aggs(EVALUATION_PARAMETERS, "foo", "bar");
        accuracy.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
        assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
    }
@@ -0,0 +1,46 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.AucRocPoint;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.Result;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class AucRocResultTests extends AbstractWireSerializingTestCase<Result> {

    public static Result createRandom() {
        double score = randomDoubleBetween(0.0, 1.0, true);
        Long docCount = randomBoolean() ? randomLong() : null;
        List<AucRocPoint> curve =
            Stream
                .generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble()))
                .limit(randomIntBetween(0, 20))
                .collect(Collectors.toList());
        return new Result(score, docCount, curve);
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
    }

    @Override
    protected Result createTestInstance() {
        return createRandom();
    }

    @Override
    protected Writeable.Reader<Result> instanceReader() {
        return Result::new;
    }
}
@@ -0,0 +1,34 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;

public class AucRocTests extends AbstractSerializingTestCase<AucRoc> {

    @Override
    protected AucRoc doParseInstance(XContentParser parser) throws IOException {
        return AucRoc.PARSER.apply(parser, null);
    }

    @Override
    protected AucRoc createTestInstance() {
        return createRandom();
    }

    @Override
    protected Writeable.Reader<AucRoc> instanceReader() {
        return AucRoc::new;
    }

    public static AucRoc createRandom() {
        return new AucRoc(randomBoolean() ? randomBoolean() : null, randomAlphaOfLength(randomIntBetween(2, 10)));
    }
}
@@ -6,12 +6,14 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@@ -23,6 +25,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@@ -33,6 +36,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
@@ -61,10 +65,17 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
            randomSubsetOf(
                Arrays.asList(
                    AccuracyTests.createRandom(),
                    AucRocTests.createRandom(),
                    PrecisionTests.createRandom(),
                    RecallTests.createRandom(),
                    MulticlassConfusionMatrixTests.createRandom()));
        return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
        boolean usesAucRoc = metrics.stream().map(EvaluationMetric::getName).anyMatch(n -> AucRoc.NAME.getPreferredName().equals(n));
        return new Classification(
            randomAlphaOfLength(10),
            randomAlphaOfLength(10),
            // If AucRoc is to be calculated, the top_classes field is required
            (usesAucRoc || randomBoolean()) ? randomAlphaOfLength(10) : null,
            metrics.isEmpty() ? null : metrics);
    }

    @Override
@@ -82,13 +93,35 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
        return Classification::new;
    }

    public void testConstructor_GivenMissingField() {
        FakeClassificationMetric metric = new FakeClassificationMetric("fake");
        ElasticsearchStatusException e =
            expectThrows(
                ElasticsearchStatusException.class,
                () -> new Classification("foo", null, null, Collections.singletonList(metric)));
        assertThat(
            e.getMessage(),
            is(equalTo("[classification] must define [predicted_field] as required by the following metrics [fake]")));
    }

    public void testConstructor_GivenEmptyMetrics() {
        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
            () -> new Classification("foo", "bar", Collections.emptyList()));
            () -> new Classification("foo", "bar", "results", Collections.emptyList()));
        assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics"));
    }

    public void testBuildSearch() {
    public void testGetFields() {
        Classification evaluation = new Classification("foo", "bar", "results", null);
        EvaluationFields fields = evaluation.getFields();
        assertThat(fields.getActualField(), is(equalTo("foo")));
        assertThat(fields.getPredictedField(), is(equalTo("bar")));
        assertThat(fields.getTopClassesField(), is(equalTo("results")));
        assertThat(fields.getPredictedClassField(), is(equalTo("results.class_name")));
        assertThat(fields.getPredictedProbabilityField(), is(equalTo("results.class_probability")));
        assertThat(fields.isPredictedProbabilityFieldNested(), is(true));
    }

    public void testBuildSearch_WithDefaultNonRequiredNestedFields() {
        QueryBuilder userProvidedQuery =
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.termQuery("field_A", "some-value"))
@@ -101,7 +134,78 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
                    .filter(QueryBuilders.termQuery("field_A", "some-value"))
                    .filter(QueryBuilders.termQuery("field_B", "some-other-value")));

        Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
        Classification evaluation = new Classification("act", "pred", null, Arrays.asList(new MulticlassConfusionMatrix()));

        SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
        assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
        assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
    }

    public void testBuildSearch_WithExplicitNonRequiredNestedFields() {
        QueryBuilder userProvidedQuery =
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.termQuery("field_A", "some-value"))
                .filter(QueryBuilders.termQuery("field_B", "some-other-value"));
        QueryBuilder expectedSearchQuery =
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.existsQuery("act"))
                .filter(QueryBuilders.existsQuery("pred"))
                .filter(QueryBuilders.boolQuery()
                    .filter(QueryBuilders.termQuery("field_A", "some-value"))
                    .filter(QueryBuilders.termQuery("field_B", "some-other-value")));

        Classification evaluation = new Classification("act", "pred", "results", Arrays.asList(new MulticlassConfusionMatrix()));

        SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
        assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
        assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
    }

    public void testBuildSearch_WithDefaultRequiredNestedFields() {
        QueryBuilder userProvidedQuery =
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.termQuery("field_A", "some-value"))
                .filter(QueryBuilders.termQuery("field_B", "some-other-value"));
        QueryBuilder expectedSearchQuery =
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.existsQuery("act"))
                .filter(
                    QueryBuilders.nestedQuery("ml.top_classes", QueryBuilders.existsQuery("ml.top_classes.class_name"), ScoreMode.None)
                        .ignoreUnmapped(true))
                .filter(
                    QueryBuilders.nestedQuery(
                        "ml.top_classes", QueryBuilders.existsQuery("ml.top_classes.class_probability"), ScoreMode.None)
                        .ignoreUnmapped(true))
                .filter(QueryBuilders.boolQuery()
                    .filter(QueryBuilders.termQuery("field_A", "some-value"))
                    .filter(QueryBuilders.termQuery("field_B", "some-other-value")));

        Classification evaluation = new Classification("act", "pred", null, Arrays.asList(new AucRoc(false, "some-value")));

        SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
        assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
        assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
    }

    public void testBuildSearch_WithExplicitRequiredNestedFields() {
        QueryBuilder userProvidedQuery =
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.termQuery("field_A", "some-value"))
                .filter(QueryBuilders.termQuery("field_B", "some-other-value"));
        QueryBuilder expectedSearchQuery =
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.existsQuery("act"))
                .filter(
                    QueryBuilders.nestedQuery("results", QueryBuilders.existsQuery("results.class_name"), ScoreMode.None)
                        .ignoreUnmapped(true))
                .filter(
                    QueryBuilders.nestedQuery("results", QueryBuilders.existsQuery("results.class_probability"), ScoreMode.None)
                        .ignoreUnmapped(true))
                .filter(QueryBuilders.boolQuery()
                    .filter(QueryBuilders.termQuery("field_A", "some-value"))
                    .filter(QueryBuilders.termQuery("field_B", "some-other-value")));

        Classification evaluation = new Classification("act", "pred", "results", Arrays.asList(new AucRoc(false, "some-value")));

        SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
        assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
@@ -114,7 +218,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
        EvaluationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
        EvaluationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);

        Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
        Classification evaluation = new Classification("act", "pred", null, Arrays.asList(metric1, metric2, metric3, metric4));
        assertThat(metric1.getResult(), isEmpty());
        assertThat(metric2.getResult(), isEmpty());
        assertThat(metric3.getResult(), isEmpty());
@@ -183,6 +287,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
        private int currentStepIndex;
        private EvaluationMetricResult result;

        FakeClassificationMetric(String name) {
            this(name, 1);
        }

        FakeClassificationMetric(String name, int numSteps) {
            this.name = name;
            this.numSteps = numSteps;
@@ -198,10 +306,14 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
            return name;
        }

        @Override
        public Set<String> getRequiredFields() {
            return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
        }

        @Override
        public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
                                                                                      String actualField,
                                                                                      String predictedField) {
                                                                                      EvaluationFields fields) {
            return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
        }
@@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
@@ -37,6 +38,7 @@ import static org.hamcrest.Matchers.not;
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {

    private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
    private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);

    @Override
    protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
@@ -83,7 +85,8 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<

    public void testAggs() {
        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
        Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred");
        Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
            confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
        assertThat(aggs, isTuple(not(empty()), empty()));
        assertThat(confusionMatrix.getResult(), isEmpty());
    }
@@ -119,7 +122,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
        confusionMatrix.process(aggs);

        assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
        assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
        Result result = confusionMatrix.getResult().get();
        assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
        assertThat(
@@ -162,7 +165,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
        confusionMatrix.process(aggs);

        assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
        assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
        Result result = confusionMatrix.getResult().get();
        assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
        assertThat(
@@ -246,7 +249,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
        confusionMatrix.process(aggsStep2);
        confusionMatrix.process(aggsStep3);

        assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
        assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
        Result result = confusionMatrix.getResult().get();
        assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
        assertThat(
@@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;

import java.io.IOException;
@@ -28,6 +29,7 @@ import static org.hamcrest.Matchers.equalTo;
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {

    private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
    private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);

    @Override
    protected Precision doParseInstance(XContentParser parser) throws IOException {
@@ -64,7 +66,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
        Precision precision = new Precision();
        precision.process(aggs);

        assertThat(precision.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
        assertThat(precision.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
        assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
    }
@@ -114,7 +116,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
        Aggregations aggs =
            new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
        Precision precision = new Precision();
        precision.aggs(EVALUATION_PARAMETERS, "foo", "bar");
        precision.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
        assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
    }
@@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;

import java.io.IOException;
@@ -27,6 +28,7 @@ import static org.hamcrest.Matchers.equalTo;
public class RecallTests extends AbstractSerializingTestCase<Recall> {

    private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
    private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);

    @Override
    protected Recall doParseInstance(XContentParser parser) throws IOException {
@@ -62,7 +64,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
        Recall recall = new Recall();
        recall.process(aggs);

        assertThat(recall.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
        assertThat(recall.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
        assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
    }
@@ -113,7 +115,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
            mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
            mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
        Recall recall = new Recall();
        recall.aggs(EVALUATION_PARAMETERS, "foo", "bar");
        recall.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
        assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
    }
@@ -10,12 +10,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class AucRocTests extends AbstractSerializingTestCase<AucRoc> {

@@ -37,91 +31,4 @@ public class AucRocTests extends AbstractSerializingTestCase<AucRoc> {
    public static AucRoc createRandom() {
        return new AucRoc(randomBoolean() ? randomBoolean() : null);
    }

    public void testCalculateAucScore_GivenZeroPercentiles() {
        double[] tpPercentiles = zeroPercentiles();
        double[] fpPercentiles = zeroPercentiles();

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        assertThat(aucRocScore, closeTo(0.5, 0.01));
    }

    public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() {
        double[] tpPercentiles = randomPercentiles();
        double[] fpPercentiles = zeroPercentiles();

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        assertThat(aucRocScore, closeTo(1.0, 0.1));
    }

    public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() {
        double[] tpPercentiles = zeroPercentiles();
        double[] fpPercentiles = randomPercentiles();

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        assertThat(aucRocScore, closeTo(0.0, 0.1));
    }

    public void testCalculateAucScore_GivenRandomPercentiles() {
        for (int i = 0; i < 20; i++) {
            double[] tpPercentiles = randomPercentiles();
            double[] fpPercentiles = randomPercentiles();

            List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
            double aucRocScore = AucRoc.calculateAucScore(curve);

            List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
            double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);

            assertThat(aucRocScore, greaterThanOrEqualTo(0.0));
            assertThat(aucRocScore, lessThanOrEqualTo(1.0));
            assertThat(inverseAucRocScore, greaterThanOrEqualTo(0.0));
            assertThat(inverseAucRocScore, lessThanOrEqualTo(1.0));
            assertThat(aucRocScore + inverseAucRocScore, closeTo(1.0, 0.05));
        }
    }

    public void testCalculateAucScore_GivenPrecalculated() {
        double[] tpPercentiles = new double[99];
        double[] fpPercentiles = new double[99];

        double[] tpSimplified = new double[] { 0.3, 0.6, 0.5, 0.8 };
        double[] fpSimplified = new double[] { 0.1, 0.3, 0.5, 0.5 };

        for (int i = 0; i < tpPercentiles.length; i++) {
            int simplifiedIndex = i / 25;
            tpPercentiles[i] = tpSimplified[simplifiedIndex];
            fpPercentiles[i] = fpSimplified[simplifiedIndex];
        }

        List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
        double aucRocScore = AucRoc.calculateAucScore(curve);

        List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
        double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);

        assertThat(aucRocScore, closeTo(0.8, 0.05));
        assertThat(inverseAucRocScore, closeTo(0.2, 0.05));
    }

    public static double[] zeroPercentiles() {
        double[] percentiles = new double[99];
        Arrays.fill(percentiles, 0.0);
        return percentiles;
    }

    public static double[] randomPercentiles() {
        double[] percentiles = new double[99];
        for (int i = 0; i < percentiles.length; i++) {
            percentiles[i] = randomDouble();
        }
        Arrays.sort(percentiles);
        return percentiles;
    }
}
@@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
@@ -26,6 +27,8 @@ import java.util.List;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {

@@ -86,6 +89,17 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
        assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics"));
    }

    public void testGetFields() {
        OutlierDetection evaluation = new OutlierDetection("foo", "bar", null);
        EvaluationFields fields = evaluation.getFields();
        assertThat(fields.getActualField(), is(equalTo("foo")));
        assertThat(fields.getPredictedField(), is(nullValue()));
        assertThat(fields.getTopClassesField(), is(nullValue()));
        assertThat(fields.getPredictedClassField(), is(nullValue()));
        assertThat(fields.getPredictedProbabilityField(), is(equalTo("bar")));
        assertThat(fields.isPredictedProbabilityFieldNested(), is(false));
    }

    public void testBuildSearch() {
        QueryBuilder userProvidedQuery =
            QueryBuilders.boolQuery()
@@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
@@ -26,6 +27,8 @@ import java.util.List;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

public class RegressionTests extends AbstractSerializingTestCase<Regression> {

@@ -73,6 +76,17 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
        assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
    }

    public void testGetFields() {
        Regression evaluation = new Regression("foo", "bar", null);
        EvaluationFields fields = evaluation.getFields();
        assertThat(fields.getActualField(), is(equalTo("foo")));
        assertThat(fields.getPredictedField(), is(equalTo("bar")));
        assertThat(fields.getTopClassesField(), is(nullValue()));
        assertThat(fields.getPredictedClassField(), is(nullValue()));
        assertThat(fields.getPredictedProbabilityField(), is(nullValue()));
        assertThat(fields.isPredictedProbabilityFieldNested(), is(false));
    }

    public void testBuildSearch() {
        QueryBuilder userProvidedQuery =
            QueryBuilders.boolQuery()
@@ -115,6 +115,11 @@ yamlRestTest {
    'ml/evaluate_data_frame/Test classification given evaluation with empty metrics',
    'ml/evaluate_data_frame/Test classification given missing actual_field',
    'ml/evaluate_data_frame/Test classification given missing predicted_field',
    'ml/evaluate_data_frame/Test classification given missing top_classes_field',
    'ml/evaluate_data_frame/Test classification auc_roc given actual_field is never equal to fish',
    'ml/evaluate_data_frame/Test classification auc_roc given predicted_class_field is never equal to mouse',
    'ml/evaluate_data_frame/Test classification auc_roc with missing class_name',
    'ml/evaluate_data_frame/Test classification accuracy with missing predicted_field',
    'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
    'ml/evaluate_data_frame/Test regression given missing actual_field',
    'ml/evaluate_data_frame/Test regression given missing predicted_field',
@@ -16,6 +16,7 @@ import org.elasticsearch.search.aggregations.MultiBucketConsumerService.TooManyB
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
@@ -24,13 +25,17 @@ import org.junit.After;
import org.junit.Before;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.IntStream;

import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
@@ -38,16 +43,19 @@ import static org.hamcrest.Matchers.notANumber;

public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

    private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
    static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";

    private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
    private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction";
    private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
    private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
    private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction";
    private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
    private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
    private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction";
    static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
    static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction";
    static final String ANIMAL_NAME_PREDICTION_PROB_FIELD = "animal_name_prediction_prob";
    static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
    static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
    static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction";
    static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
    static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
    static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction";
    static final String IS_PREDATOR_PREDICTION_PROBABILITY_FIELD = "predator_prediction_probability";
    static final String ML_TOP_CLASSES_FIELD = "ml_results";

    @Before
    public void setup() {
@@ -67,7 +75,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
    public void testEvaluate_DefaultMetrics() {
        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
            evaluateDataFrame(
                ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null));
                ANIMALS_DATA_INDEX,
                new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, null));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
        assertThat(
@@ -82,6 +91,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
                new Classification(
                    ANIMAL_NAME_KEYWORD_FIELD,
                    ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
                    null,
                    Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
@@ -116,6 +126,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
                new Classification(
                    actualField,
                    predictedField,
                    null,
                    Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));

        Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
@@ -139,9 +150,37 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
        assertThat(recallResult.getAvgRecall(), equalTo(0.0));
    }

    private AucRoc.Result evaluateAucRoc(boolean includeCurve) {
        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
            evaluateDataFrame(
                ANIMALS_DATA_INDEX,
                new Classification(ANIMAL_NAME_KEYWORD_FIELD, null, ML_TOP_CLASSES_FIELD, Arrays.asList(new AucRoc(includeCurve, "cat"))));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));

        AucRoc.Result aucrocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(0);
        assertThat(aucrocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
        return aucrocResult;
    }

    public void testEvaluate_AucRoc_DoNotIncludeCurve() {
        AucRoc.Result aucrocResult = evaluateAucRoc(false);
        assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
        assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
        assertThat(aucrocResult.getCurve(), hasSize(0));
    }

    public void testEvaluate_AucRoc_IncludeCurve() {
        AucRoc.Result aucrocResult = evaluateAucRoc(true);
        assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
        assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
        assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
    }

    private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) {
        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
            evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Accuracy())));
            evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Accuracy())));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -260,7 +299,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT

    private Precision.Result evaluatePrecision(String actualField, String predictedField) {
        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
            evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Precision())));
            evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Precision())));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -354,13 +393,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
                ElasticsearchStatusException.class,
                () -> evaluateDataFrame(
                    ANIMALS_DATA_INDEX,
                    new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Precision()))));
                    new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, Arrays.asList(new Precision()))));
        assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
    }

    private Recall.Result evaluateRecall(String actualField, String predictedField) {
        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
            evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Recall())));
            evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Recall())));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -469,7 +508,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
                ElasticsearchStatusException.class,
                () -> evaluateDataFrame(
                    ANIMALS_DATA_INDEX,
                    new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Recall()))));
                    new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, Arrays.asList(new Recall()))));
        assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
    }

@@ -478,7 +517,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
            evaluateDataFrame(
                ANIMALS_DATA_INDEX,
                new Classification(
                    ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
                    ANIMAL_NAME_KEYWORD_FIELD,
                    ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
                    null,
                    Arrays.asList(new MulticlassConfusionMatrix())));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@@ -561,6 +603,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
                new Classification(
                    ANIMAL_NAME_KEYWORD_FIELD,
                    ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
                    null,
                    Arrays.asList(new MulticlassConfusionMatrix(3, null))));

        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
@@ -595,7 +638,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
        assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
    }

    private static void createAnimalsIndex(String indexName) {
    static void createAnimalsIndex(String indexName) {
        client().admin().indices().prepareCreate(indexName)
            .addMapping("_doc",
                ANIMAL_NAME_KEYWORD_FIELD, "type=keyword",
@@ -605,28 +648,41 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
                NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer",
                IS_PREDATOR_KEYWORD_FIELD, "type=keyword",
                IS_PREDATOR_BOOLEAN_FIELD, "type=boolean",
                IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean")
                IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean",
                IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, "type=double",
                ML_TOP_CLASSES_FIELD, "type=nested")
            .get();
    }

    private static void indexAnimalsData(String indexName) {
    static void indexAnimalsData(String indexName) {
        List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        for (int i = 0; i < animalNames.size(); i++) {
            for (int j = 0; j < animalNames.size(); j++) {
                for (int k = 0; k < j + 1; k++) {
                    List<?> topClasses =
                        IntStream
                            .range(0, 5)
                            .mapToObj(ix -> new HashMap<String, Object>() {{
                                put("class_name", animalNames.get(ix));
                                put("class_probability", 0.4 - 0.1 * ix);
                            }})
                            .collect(toList());
                    bulkRequestBuilder.add(
                        new IndexRequest(indexName)
                            .source(
                                ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i),
                                ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()),
                                ANIMAL_NAME_PREDICTION_PROB_FIELD, animalNames.get((i + j) % animalNames.size()),
                                NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1),
                                NO_LEGS_INTEGER_FIELD, i + 1,
                                NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1,
                                IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0),
                                IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0,
                                IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0));
                                IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0,
                                IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, i % 2 == 0 ? 1.0 - 0.1 * i : 0.1 * i,
                                ML_TOP_CLASSES_FIELD, topClasses));
                }
            }
        }
@@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;

@@ -957,9 +958,15 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
dependentVariable,
predictedClassField,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
null,
Arrays.asList(
new Accuracy(),
new AucRoc(true, dependentVariableValues.get(0).toString()),
new MulticlassConfusionMatrix(),
new Precision(),
new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(5));

{ // Accuracy
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);

@@ -970,9 +977,17 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}
}

{ // AucRoc
AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1);
assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
assertThat(aucRocResult.getDocCount(), allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo(350L)));
assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0)));
}

{ // MulticlassConfusionMatrix
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1);
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(2);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
List<MulticlassConfusionMatrix.ActualClass> actualClasses = confusionMatrixResult.getConfusionMatrix();
assertThat(

@@ -990,7 +1005,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}

{ // Precision
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2);
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(3);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
for (Precision.PerClassResult klass : precisionResult.getClasses()) {
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));

@@ -999,7 +1014,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}

{ // Recall
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3);
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(4);
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
for (Recall.PerClassResult klass : recallResult.getClasses()) {
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
@@ -0,0 +1,114 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall;
import org.junit.After;
import org.junit.Before;

import java.util.Arrays;

import static java.util.stream.Collectors.toList;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.ANIMALS_DATA_INDEX;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.IS_PREDATOR_BOOLEAN_FIELD;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.IS_PREDATOR_PREDICTION_PROBABILITY_FIELD;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.createAnimalsIndex;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.indexAnimalsData;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

public class OutlierDetectionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

@Before
public void setup() {
createAnimalsIndex(ANIMALS_DATA_INDEX);
indexAnimalsData(ANIMALS_DATA_INDEX);
}

@After
public void cleanup() {
cleanUp();
}

public void testEvaluate_DefaultMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new OutlierDetection(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, null));

assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName()));
assertThat(
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
containsInAnyOrder(
AucRoc.NAME.getPreferredName(),
Precision.NAME.getPreferredName(),
Recall.NAME.getPreferredName(),
ConfusionMatrix.NAME.getPreferredName()));
}

public void testEvaluate_AllMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new OutlierDetection(
IS_PREDATOR_BOOLEAN_FIELD,
IS_PREDATOR_PREDICTION_PROBABILITY_FIELD,
Arrays.asList(
new AucRoc(false),
new Precision(Arrays.asList(0.5)),
new Recall(Arrays.asList(0.5)),
new ConfusionMatrix(Arrays.asList(0.5)))));

assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName()));
assertThat(
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
containsInAnyOrder(
AucRoc.NAME.getPreferredName(),
Precision.NAME.getPreferredName(),
Recall.NAME.getPreferredName(),
ConfusionMatrix.NAME.getPreferredName()));
}

private AucRoc.Result evaluateAucRoc(String actualField, String predictedField, boolean includeCurve) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new OutlierDetection(actualField, predictedField, Arrays.asList(new AucRoc(includeCurve))));

assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));

AucRoc.Result aucrocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(aucrocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
return aucrocResult;
}

public void testEvaluate_AucRoc_DoNotIncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, false);
assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001)));
assertThat(aucrocResult.getCurve(), hasSize(0));
}

public void testEvaluate_AucRoc_IncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, true);
assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001)));
assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
}

@Override
boolean supportsInference() {
return false;
}
}
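The two AucRoc tests above expect a score of roughly 1.0 because, in the data indexed by ClassificationEvaluationIT, every predator receives a predicted probability of at least 0.6 while every non-predator receives at most 0.3. A tiny reference implementation makes that expectation easy to check by hand; it is illustrative only (not part of this commit, and not how the plugin computes the score), and it assumes a plain java.util.List for its inputs:

// Illustrative reference only: AUC ROC as the probability that a randomly chosen
// positive document scores higher than a randomly chosen negative one, ties counting 1/2.
// With perfectly separated scores this returns 1.0, matching closeTo(1.0, 0.0001) above.
static double referenceAucRoc(List<Boolean> isPositive, List<Double> score) {
    long pairCount = 0;
    double favorable = 0;
    for (int i = 0; i < isPositive.size(); i++) {
        if (isPositive.get(i) == false) {
            continue;
        }
        for (int j = 0; j < isPositive.size(); j++) {
            if (isPositive.get(j)) {
                continue;
            }
            pairCount++;
            int cmp = Double.compare(score.get(i), score.get(j));
            favorable += cmp > 0 ? 1.0 : cmp == 0 ? 0.5 : 0.0;
        }
    }
    return pairCount == 0 ? Double.NaN : favorable / pairCount;
}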
@@ -206,25 +206,25 @@ public class DestinationIndexTests extends ESTestCase {
public void testCreateDestinationIndex_Classification() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(NUMERICAL_FIELD));
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testCreateDestinationIndex_Classification_DependentVariableIsNested() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD));
assertThat(extractValue("_doc.properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testCreateDestinationIndex_Classification_DependentVariableIsAlias() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD));
assertThat(extractValue("_doc.properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testCreateDestinationIndex_Classification_DependentVariableIsAliasToNested() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(ALIAS_TO_NESTED_FIELD));
assertThat(extractValue("_doc.properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testCreateDestinationIndex_ResultsFieldsExistsInSourceIndex() throws IOException {

@@ -322,25 +322,25 @@ public class DestinationIndexTests extends ESTestCase {
public void testUpdateMappingsToDestIndex_Classification() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(NUMERICAL_FIELD));
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsNested() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD));
assertThat(extractValue("properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAlias() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD));
assertThat(extractValue("properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAliasToNested() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NESTED_FIELD));
assertThat(extractValue("properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}

public void testUpdateMappingsToDestIndex_ResultsFieldsExistsInSourceIndex() throws IOException {
@@ -1,5 +1,14 @@
setup:

- do:
indices.create:
index: utopia
body:
mappings:
properties:
ml.top_classes:
type: nested

- do:
index:
index: utopia

@@ -14,7 +23,11 @@ setup:
"classification_field_act": "dog",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "dog", "class_probability": 0.9},
{"class_name": "cat", "class_probability": 0.1}
]
}

- do:

@@ -31,7 +44,11 @@ setup:
"classification_field_act": "cat",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.8},
{"class_name": "dog", "class_probability": 0.2}
]
}

- do:

@@ -48,7 +65,11 @@ setup:
"classification_field_act": "mouse",
"classification_field_pred": "mouse",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.3},
{"class_name": "dog", "class_probability": 0.1}
]
}

- do:

@@ -65,7 +86,11 @@ setup:
"classification_field_act": "dog",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.6},
{"class_name": "dog", "class_probability": 0.3}
]
}

- do:

@@ -82,7 +107,11 @@ setup:
"classification_field_act": "cat",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "dog", "class_probability": 0.7},
{"class_name": "cat", "class_probability": 0.3}
]
}

- do:

@@ -99,7 +128,11 @@ setup:
"classification_field_act": "dog",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "dog", "class_probability": 0.9},
{"class_name": "cat", "class_probability": 0.1}
]
}

- do:

@@ -116,7 +149,11 @@ setup:
"classification_field_act": "cat",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.8},
{"class_name": "dog", "class_probability": 0.2}
]
}

- do:

@@ -133,7 +170,11 @@ setup:
"classification_field_act": "mouse",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.8},
{"class_name": "dog", "class_probability": 0.2}
]
}

# This document misses the required fields and should be ignored

@@ -166,6 +207,7 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve

---

@@ -186,6 +228,7 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve

---

@@ -206,12 +249,13 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_true: outlier_detection.auc_roc.curve

---
"Test outlier_detection auc_roc given actual_field is always true":
- do:
catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/
catch: /\[auc_roc\] requires at least one \[all_true_field\] to have a different value than \[true\]/
ml.evaluate_data_frame:
body: >
{

@@ -230,7 +274,7 @@ setup:
---
"Test outlier_detection auc_roc given actual_field is always false":
- do:
catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/
catch: /\[auc_roc\] requires at least one \[all_false_field\] to have the value \[true\]/
ml.evaluate_data_frame:
body: >
{

@@ -371,6 +415,7 @@ setup:
}
}
- is_true: outlier_detection.auc_roc.score
- is_true: outlier_detection.auc_roc.doc_count
- is_true: outlier_detection.precision.0\.25
- is_true: outlier_detection.precision.0\.5
- is_true: outlier_detection.precision.0\.75

@@ -443,7 +488,7 @@ setup:
---
"Test outlier_detection given missing actual_field":
- do:
catch: /No documents found containing both \[missing, outlier_score\] fields/
catch: /No documents found containing all the required fields \[missing, outlier_score\]/
ml.evaluate_data_frame:
body: >
{

@@ -459,7 +504,7 @@ setup:
---
"Test outlier_detection given missing predicted_probability_field":
- do:
catch: /No documents found containing both \[is_outlier, missing\] fields/
catch: /No documents found containing all the required fields \[is_outlier, missing\]/
ml.evaluate_data_frame:
body: >
{

@@ -598,7 +643,124 @@ setup:
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { }
"metrics": {}
}
}
}
---
"Test classification auc_roc with missing class_name":
- do:
# TODO: Revisit this error message as it does not give any clue about which field is missing
catch: /Failed to build \[auc_roc\] after last required field arrived/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {}
}
}
}
}
---
"Test classification auc_roc given actual_field is never equal to fish":
- do:
catch: /\[auc_roc\] requires at least one \[classification_field_act.keyword\] to have the value \[fish\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {
"class_name": "fish"
}
}
}
}
}
---
"Test classification auc_roc given predicted_class_field is never equal to mouse":
- do:
catch: /\[auc_roc\] requires at least one \[ml.top_classes.class_name\] to have the value \[mouse\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {
"class_name": "mouse"
}
}
}
}
}
---
"Test classification auc_roc":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {
"class_name": "cat"
}
}
}
}
}
- match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve
---
"Test classification auc_roc with default top_classes_field":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"metrics": {
"auc_roc": {
"class_name": "cat"
}
}
}
}
}
- match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve
---
"Test classification accuracy with missing predicted_field":
- do:
catch: /\[classification\] must define \[predicted_field\] as required by the following metrics \[accuracy\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"metrics": { "accuracy": {} }
}
}
}

@@ -785,7 +947,7 @@ setup:
---
"Test classification given missing actual_field":
- do:
catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/
catch: /No documents found containing all the required fields \[missing, classification_field_pred.keyword\]/
ml.evaluate_data_frame:
body: >
{

@@ -801,7 +963,7 @@ setup:
---
"Test classification given missing predicted_field":
- do:
catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/
catch: /No documents found containing all the required fields \[classification_field_act.keyword, missing\]/
ml.evaluate_data_frame:
body: >
{

@@ -815,6 +977,27 @@ setup:
}

---
"Test classification given missing top_classes_field":
- do:
catch: /No documents found containing all the required fields \[classification_field_act.keyword, missing.class_name, missing.class_probability\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"top_classes_field": "missing",
"metrics": {
"auc_roc": {
"class_name": "dummy"
}
}
}
}
}
---
"Test regression given evaluation with empty metrics":
- do:
catch: /\[regression\] must have one or more metrics/

@@ -932,7 +1115,7 @@ setup:
---
"Test regression given missing actual_field":
- do:
catch: /No documents found containing both \[missing, regression_field_pred\] fields/
catch: /No documents found containing all the required fields \[missing, regression_field_pred\]/
ml.evaluate_data_frame:
body: >
{

@@ -948,7 +1131,7 @@ setup:
---
"Test regression given missing predicted_field":
- do:
catch: /No documents found containing both \[regression_field_act, missing\] fields/
catch: /No documents found containing all the required fields \[regression_field_act, missing\]/
ml.evaluate_data_frame:
body: >
{