[7.x] [ML] Implement AucRoc metric for classification (#60502) (#63051)

Przemysław Witek 2020-09-30 12:55:52 +02:00 committed by GitHub
parent 179fe9cc0e
commit 4366d58564
42 changed files with 2007 additions and 592 deletions

View File

@ -88,7 +88,7 @@ the probability that each document is an outlier.
`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve. Default value is
{"includes_curve": false}.
{"include_curve": false}.
`confusion_matrix`:::
(Optional, object) Set the different thresholds of the {olscore} at where
@ -153,9 +153,14 @@ belongs.
The data type of this field must be categorical.
`predicted_field`::
(Required, string) The field in the `index` that contains the predicted value,
(Optional, string) The field in the `index` which contains the predicted value,
in other words the results of the {classanalysis}.
`top_classes_field`::
(Optional, string) The field of the `index` which is an array of documents
of the form `{ "class_name": XXX, "class_probability": YYY }`.
This field must be defined as `nested` in the mappings.
`metrics`::
(Optional, object) Specifies the metrics that are used for the evaluation.
Available metrics:
@ -163,6 +168,24 @@ belongs.
`accuracy`:::
(Optional, object) Accuracy of predictions (per-class and overall).
`auc_roc`:::
(Optional, object) The AUC ROC (area under the curve of the receiver
operating characteristic) score and optionally the curve.
It is calculated for a specific class (provided as "class_name")
treated as positive.
`class_name`::::
(Required, string) Name of the only class that will be treated as
positive during AUC ROC calculation. Other classes will be treated as
negative ("one-vs-all" strategy). Documents which do not have `class_name`
in the list of their top classes will not be taken into account for evaluation.
The number of documents taken into account is returned in the evaluation result
(`auc_roc.doc_count` field).
`include_curve`::::
(Optional, boolean) Whether or not the curve should be returned in
addition to the score. Default value is false.
`multiclass_confusion_matrix`:::
(Optional, object) Multiclass confusion matrix.
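
The request options documented above correspond to the server-side classes added in this commit; a minimal sketch of the equivalent Java objects (the field values are illustrative, not part of the commit):

import java.util.Arrays;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;

Classification evaluation = new Classification(
    "animal",                  // actual_field
    "ml.animal_prediction",    // predicted_field
    "ml.top_classes",          // top_classes_field (nested array of {class_name, class_probability})
    Arrays.asList(new AucRoc(true, "dog")));   // include_curve=true, class_name="dog"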

View File

@ -394,7 +394,16 @@ public class Classification implements DataFrameAnalysis {
return additionalProperties;
}
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
Map<String, Object> topClassesProperties = new HashMap<>();
topClassesProperties.put("class_name", dependentVariableMapping);
topClassesProperties.put("class_probability", Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
Map<String, Object> topClassesMapping = new HashMap<>();
topClassesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
topClassesMapping.put("properties", topClassesProperties);
additionalProperties.put(resultsFieldName + ".top_classes", topClassesMapping);
return additionalProperties;
}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
@ -21,11 +22,16 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
/**
* Defines an evaluation
@ -38,14 +44,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
String getName();
/**
* Returns the field containing the actual value
* Returns the collection of fields required by evaluation
*/
String getActualField();
/**
* Returns the field containing the predicted value
*/
String getPredictedField();
EvaluationFields getFields();
/**
* Returns the list of metrics to evaluate
@ -59,27 +60,74 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
}
Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
checkRequiredFieldsAreSet(metrics);
return metrics;
}
default <T extends EvaluationMetric> void checkRequiredFieldsAreSet(List<T> metrics) {
assert (metrics == null || metrics.isEmpty()) == false;
for (Tuple<String, String> requiredField : getFields().listPotentiallyRequiredFields()) {
String fieldDescriptor = requiredField.v1();
String field = requiredField.v2();
if (field == null) {
String metricNamesString =
metrics.stream()
.filter(m -> m.getRequiredFields().contains(fieldDescriptor))
.map(EvaluationMetric::getName)
.collect(joining(", "));
if (metricNamesString.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"[{}] must define [{}] as required by the following metrics [{}]",
getName(), fieldDescriptor, metricNamesString);
}
}
}
}
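// For example (illustrative values; assumes the no-argument Precision constructor), omitting the
// predicted field while requesting a metric that needs it now fails fast with a bad-request
// exception along the lines of:
// "[classification] must define [predicted_field] as required by the following metrics [precision]"
new Classification("animal", null, null, Arrays.asList(new Precision()));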
/**
* Builds the search required to collect data to compute the evaluation result
* @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
*/
default SearchSourceBuilder buildSearch(EvaluationParameters parameters, QueryBuilder userProvidedQueryBuilder) {
Objects.requireNonNull(userProvidedQueryBuilder);
BoolQueryBuilder boolQuery =
QueryBuilders.boolQuery()
// Verify existence of required fields
.filter(QueryBuilders.existsQuery(getActualField()))
.filter(QueryBuilders.existsQuery(getPredictedField()))
Set<String> requiredFields = new HashSet<>(getRequiredFields());
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
if (getFields().getActualField() != null && requiredFields.contains(getFields().getActualField())) {
// Verify existence of the actual field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getActualField()));
}
if (getFields().getPredictedField() != null && requiredFields.contains(getFields().getPredictedField())) {
// Verify existence of the predicted field if required
boolQuery.filter(QueryBuilders.existsQuery(getFields().getPredictedField()));
}
if (getFields().getPredictedClassField() != null && requiredFields.contains(getFields().getPredictedClassField())) {
assert getFields().getTopClassesField() != null;
// Verify existence of the predicted class name field if required
QueryBuilder predictedClassFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedClassField());
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedClassFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
}
if (getFields().getPredictedProbabilityField() != null && requiredFields.contains(getFields().getPredictedProbabilityField())) {
// Verify existence of the predicted probability field if required
QueryBuilder predictedProbabilityFieldExistsQuery = QueryBuilders.existsQuery(getFields().getPredictedProbabilityField());
// The predicted probability field may be either nested (as in classification evaluation) or
// non-nested (as in outlier detection evaluation). Both modes are supported here.
if (getFields().isPredictedProbabilityFieldNested()) {
assert getFields().getTopClassesField() != null;
boolQuery.filter(
QueryBuilders.nestedQuery(getFields().getTopClassesField(), predictedProbabilityFieldExistsQuery, ScoreMode.None)
.ignoreUnmapped(true));
} else {
boolQuery.filter(predictedProbabilityFieldExistsQuery);
}
}
// Apply user-provided query
.filter(userProvidedQueryBuilder);
boolQuery.filter(userProvidedQueryBuilder);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
metric.aggs(parameters, getActualField(), getPredictedField());
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(parameters, getFields());
aggs.v1().forEach(searchSourceBuilder::aggregation);
aggs.v2().forEach(searchSourceBuilder::aggregation);
}
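// Sketch of the generated search (illustrative field names; EvaluationParameters is assumed to take
// the max-buckets limit): for a classification evaluation requesting only the auc_roc metric,
// the resulting size-0 search filters on exists("animal") plus nested exists queries for
// "ml.top_classes.class_name" and "ml.top_classes.class_probability" under "ml.top_classes".
Classification evaluation =
    new Classification("animal", null, null, Arrays.asList(new AucRoc(false, "dog")));
SearchSourceBuilder search =
    evaluation.buildSearch(new EvaluationParameters(10_000), QueryBuilders.matchAllQuery());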
@ -93,14 +141,31 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
default void process(SearchResponse searchResponse) {
Objects.requireNonNull(searchResponse);
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException(
"No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
String requiredFieldsString = String.join(", ", getRequiredFields());
throw ExceptionsHelper.badRequestException("No documents found containing all the required fields [{}]", requiredFieldsString);
}
for (EvaluationMetric metric : getMetrics()) {
metric.process(searchResponse.getAggregations());
}
}
/**
* @return list of fields which are required by at least one of the metrics
*/
default List<String> getRequiredFields() {
Set<String> requiredFieldDescriptors =
getMetrics().stream()
.map(EvaluationMetric::getRequiredFields)
.flatMap(Set::stream)
.collect(toSet());
List<String> requiredFields =
getFields().listPotentiallyRequiredFields().stream()
.filter(f -> requiredFieldDescriptors.contains(f.v1()))
.map(Tuple::v2)
.collect(toList());
return requiredFields;
}
/**
* @return true iff all the metrics have their results computed
*/
@ -117,6 +182,6 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
.map(EvaluationMetric::getResult)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());
.collect(toList());
}
}

View File

@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/**
* Encapsulates fields needed by evaluation.
*/
public final class EvaluationFields {
public static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
public static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
public static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
public static final ParseField PREDICTED_CLASS_FIELD = new ParseField("predicted_class_field");
public static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");
/**
* The field containing the actual value
*/
private final String actualField;
/**
* The field containing the predicted value
*/
private final String predictedField;
/**
* The field containing the array of top classes
*/
private final String topClassesField;
/**
* The field containing the predicted class name value
*/
private final String predictedClassField;
/**
* The field containing the predicted probability value in [0.0, 1.0]
*/
private final String predictedProbabilityField;
/**
* Whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
*/
private final boolean predictedProbabilityFieldNested;
public EvaluationFields(@Nullable String actualField,
@Nullable String predictedField,
@Nullable String topClassesField,
@Nullable String predictedClassField,
@Nullable String predictedProbabilityField,
boolean predictedProbabilityFieldNested) {
this.actualField = actualField;
this.predictedField = predictedField;
this.topClassesField = topClassesField;
this.predictedClassField = predictedClassField;
this.predictedProbabilityField = predictedProbabilityField;
this.predictedProbabilityFieldNested = predictedProbabilityFieldNested;
}
/**
* Returns the field containing the actual value
*/
public String getActualField() {
return actualField;
}
/**
* Returns the field containing the predicted value
*/
public String getPredictedField() {
return predictedField;
}
/**
* Returns the field containing the array of top classes
*/
public String getTopClassesField() {
return topClassesField;
}
/**
* Returns the field containing the predicted class name value
*/
public String getPredictedClassField() {
return predictedClassField;
}
/**
* Returns the field containing the predicted probability value in [0.0, 1.0]
*/
public String getPredictedProbabilityField() {
return predictedProbabilityField;
}
/**
* Returns whether the {@code predictedProbabilityField} should be treated as nested (e.g.: when used in exists queries).
*/
public boolean isPredictedProbabilityFieldNested() {
return predictedProbabilityFieldNested;
}
public List<Tuple<String, String>> listPotentiallyRequiredFields() {
return Arrays.asList(
Tuple.tuple(ACTUAL_FIELD.getPreferredName(), actualField),
Tuple.tuple(PREDICTED_FIELD.getPreferredName(), predictedField),
Tuple.tuple(TOP_CLASSES_FIELD.getPreferredName(), topClassesField),
Tuple.tuple(PREDICTED_CLASS_FIELD.getPreferredName(), predictedClassField),
Tuple.tuple(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField));
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
EvaluationFields that = (EvaluationFields) o;
return Objects.equals(that.actualField, this.actualField)
&& Objects.equals(that.predictedField, this.predictedField)
&& Objects.equals(that.topClassesField, this.topClassesField)
&& Objects.equals(that.predictedClassField, this.predictedClassField)
&& Objects.equals(that.predictedProbabilityField, this.predictedProbabilityField)
&& Objects.equals(that.predictedProbabilityFieldNested, this.predictedProbabilityFieldNested);
}
@Override
public int hashCode() {
return Objects.hash(
actualField, predictedField, topClassesField, predictedClassField, predictedProbabilityField, predictedProbabilityFieldNested);
}
}
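
A brief usage sketch (values are illustrative): each descriptor/value pair returned by listPotentiallyRequiredFields is matched against the descriptors that the requested metrics declare via getRequiredFields.

EvaluationFields fields = new EvaluationFields(
    "animal",                             // actual_field
    null,                                 // predicted_field (not provided)
    "ml.top_classes",                     // top_classes_field
    "ml.top_classes.class_name",          // predicted_class_field
    "ml.top_classes.class_probability",   // predicted_probability_field
    true);                                // the probability field is nested
for (Tuple<String, String> field : fields.listPotentiallyRequiredFields()) {
    System.out.println(field.v1() + " -> " + field.v2());   // e.g. "predicted_field -> null"
}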

View File

@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import java.util.List;
import java.util.Optional;
import java.util.Set;
/**
* {@link EvaluationMetric} class represents a metric to evaluate.
@ -26,16 +27,18 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
*/
String getName();
/**
* Returns the set of fields that this metric requires in order to be calculated.
*/
Set<String> getRequiredFields();
/**
* Builds the aggregation that collect required data to compute the metric
* @param parameters settings that may be needed by aggregations
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value (class name or probability)
* @param fields fields that may be needed by aggregations
* @return the aggregations required to compute the metric
*/
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField);
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters, EvaluationFields fields);
/**
* Processes given aggregations as a step towards computing result

View File

@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
/**
* Encapsulates parameters needed by evaluation.
*/
public class EvaluationParameters {
public final class EvaluationParameters {
/**
* Maximum number of buckets allowed in any single search request.

View File

@ -10,13 +10,13 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
@ -63,19 +63,28 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Outlier detection metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRoc.NAME)),
AucRoc::fromXContent),
new ParseField(
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc.NAME)),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, Precision.NAME)),
Precision::fromXContent),
new ParseField(
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision.NAME)),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, Recall.NAME)),
Recall::fromXContent),
new ParseField(
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall.NAME)),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME)),
ConfusionMatrix::fromXContent),
// Classification metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, AucRoc.NAME)),
AucRoc::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)),
MulticlassConfusionMatrix::fromXContent),
@ -83,15 +92,11 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)),
Accuracy::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent),
new ParseField(registeredMetricName(Classification.NAME, Precision.NAME)),
Precision::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent),
new ParseField(registeredMetricName(Classification.NAME, Recall.NAME)),
Recall::fromXContent),
// Regression metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class,
@ -124,17 +129,23 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluation metrics
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(OutlierDetection.NAME, AucRoc.NAME),
AucRoc::new),
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(OutlierDetection.NAME, Precision.NAME),
Precision::new),
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(OutlierDetection.NAME, Recall.NAME),
Recall::new),
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME),
ConfusionMatrix::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Classification.NAME, AucRoc.NAME),
AucRoc::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
MulticlassConfusionMatrix::new),
@ -142,13 +153,11 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
registeredMetricName(Classification.NAME, Accuracy.NAME),
Accuracy::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new),
registeredMetricName(Classification.NAME, Precision.NAME),
Precision::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new),
registeredMetricName(Classification.NAME, Recall.NAME),
Recall::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
MeanSquaredError::new),
@ -163,15 +172,15 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
RSquared::new),
// Evaluation metrics results
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(OutlierDetection.NAME, AucRoc.NAME),
AucRoc.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(OutlierDetection.NAME, ScoreByThresholdResult.NAME),
ScoreByThresholdResult::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(OutlierDetection.NAME, ConfusionMatrix.NAME),
ConfusionMatrix.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Classification.NAME, AucRoc.NAME),
AucRoc.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
MulticlassConfusionMatrix.Result::new),
@ -179,13 +188,11 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
registeredMetricName(Classification.NAME, Accuracy.NAME),
Accuracy.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new),
registeredMetricName(Classification.NAME, Precision.NAME),
Precision.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new),
registeredMetricName(Classification.NAME, Recall.NAME),
Recall.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
MeanSquaredError.Result::new),

View File

@ -0,0 +1,317 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
* The ROC curve is a plot of the TPR (true positive rate) against
* the FPR (false positive rate) over a varying threshold.
*
This particular implementation makes use of ES aggregations
* to calculate the curve. It then uses the trapezoidal rule to calculate
* the AUC.
*
* In particular, in order to calculate the ROC, we get percentiles of TP
* and FP against the predicted probability. We call those Rate-Threshold
* curves. We then scan ROC points from each Rate-Threshold curve against the
* other using interpolation. This gives us an approximation of the ROC curve
* that has the advantage of being efficient and resilient to some edge cases.
*
* When this is used for multi-class classification, it will calculate the ROC
* curve of each class versus the rest.
*/
public abstract class AbstractAucRoc implements EvaluationMetric {
public static final ParseField NAME = new ParseField("auc_roc");
protected AbstractAucRoc() {}
@Override
public String getName() {
return NAME.getPreferredName();
}
protected static double[] percentilesArray(Percentiles percentiles) {
double[] result = new double[99];
percentiles.forEach(percentile -> {
if (Double.isNaN(percentile.getValue())) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at all the percentiles values to be finite numbers", NAME.getPreferredName());
}
result[((int) percentile.getPercent()) - 1] = percentile.getValue();
});
return result;
}
/**
* Visible for testing
*/
protected static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) {
assert tpPercentiles.length == fpPercentiles.length;
assert tpPercentiles.length == 99;
List<AucRocPoint> aucRocCurve = new ArrayList<>();
aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);
aucRocCurve.addAll(tpCurve.scanPoints(fpCurve));
aucRocCurve.addAll(fpCurve.scanPoints(tpCurve));
Collections.sort(aucRocCurve);
return aucRocCurve;
}
/**
* Visible for testing
*/
protected static double calculateAucScore(List<AucRocPoint> rocCurve) {
// Calculates AUC based on the trapezoid rule
double aucRoc = 0.0;
for (int i = 1; i < rocCurve.size(); i++) {
AucRocPoint left = rocCurve.get(i - 1);
AucRocPoint right = rocCurve.get(i);
aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2;
}
return aucRoc;
}
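// For intuition, a small hand-checkable example of the trapezoid rule above (hypothetical points,
// runnable from a test in the same package since AucRocPoint's constructor is package-private):
List<AucRocPoint> curve = Arrays.asList(
    new AucRocPoint(0.0, 0.0, 1.0),   // (tpr, fpr, threshold)
    new AucRocPoint(0.6, 0.2, 0.5),
    new AucRocPoint(1.0, 1.0, 0.0));
// Trapezoids: (0.2 - 0.0) * (0.6 + 0.0) / 2 + (1.0 - 0.2) * (1.0 + 0.6) / 2 = 0.06 + 0.64
double auc = calculateAucScore(curve);   // 0.70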
private static class RateThresholdCurve {
private final double[] percentiles;
private final boolean isTp;
private RateThresholdCurve(double[] percentiles, boolean isTp) {
this.percentiles = percentiles;
this.isTp = isTp;
}
private double getRate(int index) {
return 1 - 0.01 * (index + 1);
}
private double getThreshold(int index) {
return percentiles[index];
}
private double interpolateRate(double threshold) {
int binarySearchResult = Arrays.binarySearch(percentiles, threshold);
if (binarySearchResult >= 0) {
return getRate(binarySearchResult);
} else {
int right = (binarySearchResult * -1) -1;
int left = right - 1;
if (right >= percentiles.length) {
return 0.0;
} else if (left < 0) {
return 1.0;
} else {
double rightRate = getRate(right);
double leftRate = getRate(left);
return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate);
}
}
}
private List<AucRocPoint> scanPoints(RateThresholdCurve againstCurve) {
List<AucRocPoint> points = new ArrayList<>();
for (int index = 0; index < percentiles.length; index++) {
double rate = getRate(index);
double scannedThreshold = getThreshold(index);
double againstRate = againstCurve.interpolateRate(scannedThreshold);
AucRocPoint point;
if (isTp) {
point = new AucRocPoint(rate, againstRate, scannedThreshold);
} else {
point = new AucRocPoint(againstRate, rate, scannedThreshold);
}
points.add(point);
}
return points;
}
}
public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXContentObject, Writeable {
private static final String TPR = "tpr";
private static final String FPR = "fpr";
private static final String THRESHOLD = "threshold";
private final double tpr;
private final double fpr;
private final double threshold;
AucRocPoint(double tpr, double fpr, double threshold) {
this.tpr = tpr;
this.fpr = fpr;
this.threshold = threshold;
}
private AucRocPoint(StreamInput in) throws IOException {
this.tpr = in.readDouble();
this.fpr = in.readDouble();
this.threshold = in.readDouble();
}
@Override
public int compareTo(AucRocPoint o) {
return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed()
.thenComparing(p -> p.fpr)
.thenComparing(p -> p.tpr)
.compare(this, o);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(tpr);
out.writeDouble(fpr);
out.writeDouble(threshold);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TPR, tpr);
builder.field(FPR, fpr);
builder.field(THRESHOLD, threshold);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AucRocPoint that = (AucRocPoint) o;
return tpr == that.tpr
&& fpr == that.fpr
&& threshold == that.threshold;
}
@Override
public int hashCode() {
return Objects.hash(tpr, fpr, threshold);
}
@Override
public String toString() {
return Strings.toString(this);
}
}
private static double interpolate(double x, double x1, double y1, double x2, double y2) {
return y1 + (x - x1) * (y2 - y1) / (x2 - x1);
}
public static class Result implements EvaluationMetricResult {
private static final String SCORE = "score";
private static final String DOC_COUNT = "doc_count";
private static final String CURVE = "curve";
private final double score;
private final Long docCount;
private final List<AucRocPoint> curve;
public Result(double score, Long docCount, List<AucRocPoint> curve) {
this.score = score;
this.docCount = docCount;
this.curve = Objects.requireNonNull(curve);
}
public Result(StreamInput in) throws IOException {
this.score = in.readDouble();
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.docCount = in.readOptionalLong();
} else {
this.docCount = null;
}
this.curve = in.readList(AucRocPoint::new);
}
public double getScore() {
return score;
}
public Long getDocCount() {
return docCount;
}
public List<AucRocPoint> getCurve() {
return Collections.unmodifiableList(curve);
}
@Override
public String getWriteableName() {
return registeredMetricName(Classification.NAME, NAME);
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(score);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeOptionalLong(docCount);
}
out.writeList(curve);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SCORE, score);
if (docCount != null) {
builder.field(DOC_COUNT, docCount);
}
if (curve.isEmpty() == false) {
builder.field(CURVE, curve);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return score == that.score
&& Objects.equals(docCount, that.docCount)
&& Objects.equals(curve, that.curve);
}
@Override
public int hashCode() {
return Objects.hash(score, docCount, curve);
}
}
}

View File

@ -11,6 +11,7 @@ import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
@ -22,6 +23,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -33,6 +35,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
@ -95,21 +98,24 @@ public class Accuracy implements EvaluationMetric {
return NAME.getPreferredName();
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField);
this.actualField.trySet(fields.getActualField());
List<AggregationBuilder> aggs = new ArrayList<>();
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
if (overallAccuracy.get() == null) {
Script script = PainlessScripts.buildIsEqualScript(actualField, predictedField);
Script script = PainlessScripts.buildIsEqualScript(fields.getActualField(), fields.getPredictedField());
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(script));
}
if (result.get() == null) {
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs =
matrix.aggs(parameters, actualField, predictedField);
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(parameters, fields);
aggs.addAll(matrixAggs.v1());
pipelineAggs.addAll(matrixAggs.v2());
}

View File

@ -0,0 +1,228 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.bucket.nested.Nested;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
* The ROC curve is a plot of the TPR (true positive rate) against
* the FPR (false positive rate) over a varying threshold.
*
This particular implementation makes use of ES aggregations
* to calculate the curve. It then uses the trapezoidal rule to calculate
* the AUC.
*
* In particular, in order to calculate the ROC, we get percentiles of TP
* and FP against the predicted probability. We call those Rate-Threshold
* curves. We then scan ROC points from each Rate-Threshold curve against the
* other using interpolation. This gives us an approximation of the ROC curve
* that has the advantage of being efficient and resilient to some edge cases.
*
* When this is used for multi-class classification, it will calculate the ROC
* curve of each class versus the rest.
*/
public class AucRoc extends AbstractAucRoc {
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ConstructingObjectParser<AucRoc, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new AucRoc((Boolean) a[0], (String) a[1]));
static {
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), CLASS_NAME);
}
private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
private static final String NESTED_AGG_NAME = "nested";
private static final String NESTED_FILTER_AGG_NAME = "nested_filter";
private static final String PERCENTILES_AGG_NAME = "percentiles";
public static AucRoc fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final boolean includeCurve;
private final String className;
private final SetOnce<EvaluationFields> fields = new SetOnce<>();
private final SetOnce<EvaluationMetricResult> result = new SetOnce<>();
public AucRoc(Boolean includeCurve, String className) {
this.includeCurve = includeCurve == null ? false : includeCurve;
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME.getPreferredName());
}
public AucRoc(StreamInput in) throws IOException {
this.includeCurve = in.readBoolean();
this.className = in.readOptionalString();
}
@Override
public String getWriteableName() {
return registeredMetricName(Classification.NAME, NAME);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(includeCurve);
out.writeOptionalString(className);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
if (className != null) {
builder.field(CLASS_NAME.getPreferredName(), className);
}
builder.endObject();
return builder;
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(
EvaluationFields.ACTUAL_FIELD.getPreferredName(),
EvaluationFields.PREDICTED_CLASS_FIELD.getPreferredName(),
EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AucRoc that = (AucRoc) o;
return includeCurve == that.includeCurve
&& Objects.equals(className, that.className);
}
@Override
public int hashCode() {
return Objects.hash(includeCurve, className);
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
EvaluationFields fields) {
if (result.get() != null) {
return Tuple.tuple(Arrays.asList(), Arrays.asList());
}
// Store given {@code fields} for the purpose of generating error messages in {@code process}.
this.fields.trySet(fields);
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
AggregationBuilder percentilesAgg =
AggregationBuilders
.percentiles(PERCENTILES_AGG_NAME)
.field(fields.getPredictedProbabilityField())
.percentiles(percentiles);
AggregationBuilder nestedAgg =
AggregationBuilders
.nested(NESTED_AGG_NAME, fields.getTopClassesField())
.subAggregation(
AggregationBuilders
.filter(NESTED_FILTER_AGG_NAME, QueryBuilders.termQuery(fields.getPredictedClassField(), className))
.subAggregation(percentilesAgg));
QueryBuilder actualIsTrueQuery = QueryBuilders.termQuery(fields.getActualField(), className);
AggregationBuilder percentilesForClassValueAgg =
AggregationBuilders
.filter(TRUE_AGG_NAME, actualIsTrueQuery)
.subAggregation(nestedAgg);
AggregationBuilder percentilesForRestAgg =
AggregationBuilders
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery))
.subAggregation(nestedAgg);
return Tuple.tuple(
Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg),
Arrays.asList());
}
@Override
public void process(Aggregations aggs) {
if (result.get() != null) {
return;
}
Filter classAgg = aggs.get(TRUE_AGG_NAME);
Nested classNested = classAgg.getAggregations().get(NESTED_AGG_NAME);
Filter classNestedFilter = classNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
if (classAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getActualField(), className);
}
if (classNestedFilter.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getPredictedClassField(), className);
}
Percentiles classPercentiles = classNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] tpPercentiles = percentilesArray(classPercentiles);
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
Nested restNested = restAgg.getAggregations().get(NESTED_AGG_NAME);
Filter restNestedFilter = restNested.getAggregations().get(NESTED_FILTER_AGG_NAME);
if (restAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have a different value than [{}]",
getName(), fields.get().getActualField(), className);
}
if (restNestedFilter.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]",
getName(), fields.get().getPredictedClassField(), className);
}
Percentiles restPercentiles = restNestedFilter.getAggregations().get(PERCENTILES_AGG_NAME);
double[] fpPercentiles = percentilesArray(restPercentiles);
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
result.set(
new Result(
aucRocScore,
classNestedFilter.getDocCount() + restNestedFilter.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result.get());
}
}
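
To make the flow above concrete, a hedged sketch of driving the metric directly (field names are illustrative, and EvaluationParameters is assumed to take the max-buckets limit):

AucRoc aucRoc = new AucRoc(true, "dog");   // include_curve=true, class_name="dog"
EvaluationFields fields = new EvaluationFields(
    "animal", null, "ml.top_classes",
    "ml.top_classes.class_name", "ml.top_classes.class_probability", true);
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
    aucRoc.aggs(new EvaluationParameters(10_000), fields);
// aggs.v1() holds the "auc_roc_true" and "auc_roc_non_true" filter aggregations, each wrapping a
// nested aggregation over "ml.top_classes" filtered to class_name == "dog" with a 1..99
// percentiles sub-aggregation over "ml.top_classes.class_probability". Once the search has run,
// aucRoc.process(searchResponse.getAggregations()) stores the Result (score, doc_count, curve).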

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@ -13,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -21,6 +23,9 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_FIELD;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.TOP_CLASSES_FIELD;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
@ -30,17 +35,22 @@ public class Classification implements Evaluation {
public static final ParseField NAME = new ParseField("classification");
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
private static final ParseField METRICS = new ParseField("metrics");
private static final String DEFAULT_TOP_CLASSES_FIELD = "ml.top_classes";
private static final String DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX = ".class_name";
private static final String DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX = ".class_probability";
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
public static final ConstructingObjectParser<Classification, Void> PARSER =
new ConstructingObjectParser<>(
NAME.getPreferredName(),
a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List<EvaluationMetric>) a[3]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTED_FIELD);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TOP_CLASSES_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
}
@ -50,25 +60,35 @@ public class Classification implements Evaluation {
}
/**
* The field containing the actual value
* The value of this field is assumed to be categorical
* The collection of fields in the index being evaluated.
* fields.getActualField() is assumed to be a ground truth label.
* fields.getPredictedField() is assumed to be a predicted label.
* fields.getPredictedClassField() and fields.getPredictedProbabilityField() are assumed to be properties under the same nested field.
*/
private final String actualField;
/**
* The field containing the predicted value
* The value of this field is assumed to be categorical
*/
private final String predictedField;
private final EvaluationFields fields;
/**
* The list of metrics to calculate
*/
private final List<EvaluationMetric> metrics;
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
public Classification(String actualField,
@Nullable String predictedField,
@Nullable String topClassesField,
@Nullable List<EvaluationMetric> metrics) {
if (topClassesField == null) {
topClassesField = DEFAULT_TOP_CLASSES_FIELD;
}
String predictedClassField = topClassesField + DEFAULT_PREDICTED_CLASS_FIELD_SUFFIX;
String predictedProbabilityField = topClassesField + DEFAULT_PREDICTED_PROBABILITY_FIELD_SUFFIX;
this.fields =
new EvaluationFields(
ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD),
predictedField,
topClassesField,
predictedClassField,
predictedProbabilityField,
true);
this.metrics = initMetrics(metrics, Classification::defaultMetrics);
}
@ -77,8 +97,18 @@ public class Classification implements Evaluation {
}
public Classification(StreamInput in) throws IOException {
this.actualField = in.readString();
this.predictedField = in.readString();
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.fields =
new EvaluationFields(
in.readString(),
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalString(),
true);
} else {
this.fields = new EvaluationFields(in.readString(), in.readString(), null, null, null, true);
}
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
}
@ -88,13 +118,8 @@ public class Classification implements Evaluation {
}
@Override
public String getActualField() {
return actualField;
}
@Override
public String getPredictedField() {
return predictedField;
public EvaluationFields getFields() {
return fields;
}
@Override
@ -109,17 +134,28 @@ public class Classification implements Evaluation {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualField);
out.writeString(predictedField);
out.writeString(fields.getActualField());
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeOptionalString(fields.getPredictedField());
out.writeOptionalString(fields.getTopClassesField());
out.writeOptionalString(fields.getPredictedClassField());
out.writeOptionalString(fields.getPredictedProbabilityField());
} else {
out.writeString(fields.getPredictedField());
}
out.writeNamedWriteableList(metrics);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField());
if (fields.getPredictedField() != null) {
builder.field(PREDICTED_FIELD.getPreferredName(), fields.getPredictedField());
}
if (fields.getTopClassesField() != null) {
builder.field(TOP_CLASSES_FIELD.getPreferredName(), fields.getTopClassesField());
}
builder.startObject(METRICS.getPreferredName());
for (EvaluationMetric metric : metrics) {
builder.field(metric.getName(), metric);
@ -135,13 +171,12 @@ public class Classification implements Evaluation {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Classification that = (Classification) o;
return Objects.equals(that.actualField, this.actualField)
&& Objects.equals(that.predictedField, this.predictedField)
return Objects.equals(that.fields, this.fields)
&& Objects.equals(that.metrics, this.metrics);
}
@Override
public int hashCode() {
return Objects.hash(actualField, predictedField, metrics);
return Objects.hash(fields, metrics);
}
}
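
Finally, a short sketch of the new field defaulting (the index and result field names are illustrative):

Classification eval = new Classification(
    "animal", "ml.animal_prediction", null, Arrays.asList(new AucRoc(false, "dog")));
// With top_classes_field omitted, the evaluation falls back to:
//   eval.getFields().getTopClassesField()            -> "ml.top_classes"
//   eval.getFields().getPredictedClassField()        -> "ml.top_classes.class_name"
//   eval.getFields().getPredictedProbabilityField()  -> "ml.top_classes.class_probability"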

View File

@ -13,6 +13,7 @@ import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -27,6 +28,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -39,6 +41,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static java.util.Comparator.comparing;
@ -125,10 +128,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
return size;
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
String actualField = fields.getActualField();
String predictedField = fields.getPredictedField();
if (topActualClassNames.get() == null && actualClassesCardinality.get() == null) { // This is step 1
return Tuple.tuple(
Arrays.asList(

View File

@ -11,6 +11,7 @@ import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
@ -28,6 +29,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -40,6 +42,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@ -90,10 +93,16 @@ public class Precision implements EvaluationMetric {
return NAME.getPreferredName();
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
String actualField = fields.getActualField();
String predictedField = fields.getPredictedField();
// Store given {@code actualField} for the purpose of generating error messages in {@code process}.
this.actualField.trySet(actualField);
if (topActualClassNames.get() == null) { // This is step 1

View File

@ -11,6 +11,7 @@ import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
@ -25,6 +26,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -37,6 +39,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
@ -84,10 +87,16 @@ public class Recall implements EvaluationMetric {
return NAME.getPreferredName();
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
String actualField = fields.getActualField();
String predictedField = fields.getPredictedField();
// Store given {@code actualField} for the purpose of generating error messages in {@code process}.
this.actualField.trySet(actualField);
if (result.get() != null) {

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
@ -17,6 +18,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -26,6 +28,7 @@ import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection.actualIsTrueQuery;
@ -66,13 +69,20 @@ abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
return builder;
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(
EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedProbabilityField) {
EvaluationFields fields) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
String actualField = fields.getActualField();
String predictedProbabilityField = fields.getPredictedProbabilityField();
return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), Collections.emptyList());
}

View File

@ -5,14 +5,13 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
@ -21,20 +20,19 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
@ -58,30 +56,28 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetect
* When this is used for multi-class classification, it will calculate the ROC
* curve of each class versus the rest.
*/
public class AucRoc implements EvaluationMetric {
public static final ParseField NAME = new ParseField("auc_roc");
public class AucRoc extends AbstractAucRoc {
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
public static final ConstructingObjectParser<AucRoc, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
a -> new AucRoc((Boolean) a[0]));
public static final ConstructingObjectParser<AucRoc, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new AucRoc((Boolean) a[0]));
static {
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
}
private static final String PERCENTILES = "percentiles";
private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
private static final String PERCENTILES_AGG_NAME = "percentiles";
public static AucRoc fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final boolean includeCurve;
private EvaluationMetricResult result;
private final SetOnce<EvaluationFields> fields = new SetOnce<>();
private final SetOnce<EvaluationMetricResult> result = new SetOnce<>();
public AucRoc(Boolean includeCurve) {
this.includeCurve = includeCurve == null ? false : includeCurve;
@ -110,8 +106,9 @@ public class AucRoc implements EvaluationMetric {
}
@Override
public String getName() {
return NAME.getPreferredName();
public Set<String> getRequiredFields() {
return Sets.newHashSet(
EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
}
@Override
@ -119,7 +116,7 @@ public class AucRoc implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AucRoc that = (AucRoc) o;
return Objects.equals(includeCurve, that.includeCurve);
return includeCurve == that.includeCurve;
}
@Override
@ -129,22 +126,29 @@ public class AucRoc implements EvaluationMetric {
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedProbabilityField) {
if (result != null) {
EvaluationFields fields) {
if (result.get() != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
// Store given {@code fields} for the purpose of generating error messages in {@code process}.
this.fields.trySet(fields);
String actualField = fields.getActualField();
String predictedProbabilityField = fields.getPredictedProbabilityField();
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
AggregationBuilder percentilesAgg =
AggregationBuilders
.percentiles(PERCENTILES_AGG_NAME)
.field(predictedProbabilityField)
.percentiles(percentiles);
AggregationBuilder percentilesForClassValueAgg =
AggregationBuilders
.filter(TRUE_AGG_NAME, actualIsTrueQuery(actualField))
.subAggregation(
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
.subAggregation(percentilesAgg);
AggregationBuilder percentilesForRestAgg =
AggregationBuilders
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
.subAggregation(
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
.subAggregation(percentilesAgg);
return Tuple.tuple(
Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg),
Collections.emptyList());
@ -152,216 +156,33 @@ public class AucRoc implements EvaluationMetric {
@Override
public void process(Aggregations aggs) {
if (result.get() != null) {
return;
}
Filter classAgg = aggs.get(TRUE_AGG_NAME);
if (classAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have the value [{}]", getName(), fields.get().getActualField(), "true");
}
double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES_AGG_NAME));
Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
double[] tpPercentiles =
percentilesArray(
classAgg.getAggregations().get(PERCENTILES),
"[" + getName() + "] requires at least one actual_field to have the value [true]");
double[] fpPercentiles =
percentilesArray(
restAgg.getAggregations().get(PERCENTILES),
"[" + getName() + "] requires at least one actual_field to have a different value than [true]");
if (restAgg.getDocCount() == 0) {
throw ExceptionsHelper.badRequestException(
"[{}] requires at least one [{}] to have a different value than [{}]", getName(), fields.get().getActualField(), "true");
}
double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES_AGG_NAME));
List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = calculateAucScore(aucRocCurve);
result = new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
result.set(
new Result(
aucRocScore,
classAgg.getDocCount() + restAgg.getDocCount(),
includeCurve ? aucRocCurve : Collections.emptyList()));
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) {
double[] result = new double[99];
percentiles.forEach(percentile -> {
if (Double.isNaN(percentile.getValue())) {
throw ExceptionsHelper.badRequestException(errorIfUndefined);
}
result[((int) percentile.getPercent()) - 1] = percentile.getValue();
});
return result;
}
/**
* Visible for testing
*/
static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) {
assert tpPercentiles.length == fpPercentiles.length;
assert tpPercentiles.length == 99;
List<AucRocPoint> aucRocCurve = new ArrayList<>();
aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);
aucRocCurve.addAll(tpCurve.scanPoints(fpCurve));
aucRocCurve.addAll(fpCurve.scanPoints(tpCurve));
Collections.sort(aucRocCurve);
return aucRocCurve;
}
/**
* Visible for testing
*/
static double calculateAucScore(List<AucRocPoint> rocCurve) {
// Calculates AUC based on the trapezoid rule
double aucRoc = 0.0;
for (int i = 1; i < rocCurve.size(); i++) {
AucRocPoint left = rocCurve.get(i - 1);
AucRocPoint right = rocCurve.get(i);
aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2;
}
return aucRoc;
}
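The trapezoid accumulation above is easy to check by hand. A minimal standalone sketch (the helper and the input values below are invented purely for illustration, they are not part of this change):
    // Sketch only: ROC points as {fpr, tpr} pairs, sorted by increasing fpr.
    static double trapezoidAuc(double[][] roc) {
        double auc = 0.0;
        for (int i = 1; i < roc.length; i++) {
            auc += (roc[i][0] - roc[i - 1][0]) * (roc[i][1] + roc[i - 1][1]) / 2;
        }
        return auc;
    }
    // trapezoidAuc(new double[][] { {0.0, 0.0}, {0.1, 0.7}, {0.4, 0.9}, {1.0, 1.0} })
    //   = 0.1 * 0.35 + 0.3 * 0.8 + 0.6 * 0.95 = 0.845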
private static class RateThresholdCurve {
private final double[] percentiles;
private final boolean isTp;
private RateThresholdCurve(double[] percentiles, boolean isTp) {
this.percentiles = percentiles;
this.isTp = isTp;
}
private double getRate(int index) {
return 1 - 0.01 * (index + 1);
}
private double getThreshold(int index) {
return percentiles[index];
}
private double interpolateRate(double threshold) {
int binarySearchResult = Arrays.binarySearch(percentiles, threshold);
if (binarySearchResult >= 0) {
return getRate(binarySearchResult);
} else {
int right = (binarySearchResult * -1) -1;
int left = right - 1;
if (right >= percentiles.length) {
return 0.0;
} else if (left < 0) {
return 1.0;
} else {
double rightRate = getRate(right);
double leftRate = getRate(left);
return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate);
}
}
}
private List<AucRocPoint> scanPoints(RateThresholdCurve againstCurve) {
List<AucRocPoint> points = new ArrayList<>();
for (int index = 0; index < percentiles.length; index++) {
double rate = getRate(index);
double scannedThreshold = getThreshold(index);
double againstRate = againstCurve.interpolateRate(scannedThreshold);
AucRocPoint point;
if (isTp) {
point = new AucRocPoint(rate, againstRate, scannedThreshold);
} else {
point = new AucRocPoint(againstRate, rate, scannedThreshold);
}
points.add(point);
}
return points;
}
}
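The percentiles act as thresholds here: if p[i] is the i-th percentile (1-based) of the predicted probability within a group, then a fraction 1 - i/100 of that group scores above it, which is exactly what getRate returns for the 0-based array index. A short illustrative sketch (values invented):
    // 75th percentile of scores among actual positives used as the decision threshold:
    int i = 75;
    double truePositiveRateAtThatThreshold = 1.0 - 0.01 * i;   // 0.25 of the positives exceed it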
public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXContentObject, Writeable {
double tpr;
double fpr;
double threshold;
private AucRocPoint(double tpr, double fpr, double threshold) {
this.tpr = tpr;
this.fpr = fpr;
this.threshold = threshold;
}
private AucRocPoint(StreamInput in) throws IOException {
this.tpr = in.readDouble();
this.fpr = in.readDouble();
this.threshold = in.readDouble();
}
@Override
public int compareTo(AucRocPoint o) {
return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed()
.thenComparing(p -> p.fpr)
.thenComparing(p -> p.tpr)
.compare(this, o);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(tpr);
out.writeDouble(fpr);
out.writeDouble(threshold);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("tpr", tpr);
builder.field("fpr", fpr);
builder.field("threshold", threshold);
builder.endObject();
return builder;
}
@Override
public String toString() {
return Strings.toString(this);
}
}
private static double interpolate(double x, double x1, double y1, double x2, double y2) {
return y1 + (x - x1) * (y2 - y1) / (x2 - x1);
}
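    // Worked example of the linear interpolation above (illustrative numbers):
    // interpolate(0.5, 0.4, 0.8, 0.6, 0.6) = 0.8 + (0.5 - 0.4) * (0.6 - 0.8) / (0.6 - 0.4) = 0.7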
public static class Result implements EvaluationMetricResult {
private final double score;
private final List<AucRocPoint> curve;
public Result(double score, List<AucRocPoint> curve) {
this.score = score;
this.curve = Objects.requireNonNull(curve);
}
public Result(StreamInput in) throws IOException {
this.score = in.readDouble();
this.curve = in.readList(AucRocPoint::new);
}
@Override
public String getWriteableName() {
return registeredMetricName(OutlierDetection.NAME, NAME);
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(score);
out.writeList(curve);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("score", score);
if (curve.isEmpty() == false) {
builder.field("curve", curve);
}
builder.endObject();
return builder;
}
return Optional.ofNullable(result.get());
}
}

View File

@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -23,6 +24,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_PROBABILITY_FIELD;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
@ -32,8 +35,6 @@ public class OutlierDetection implements Evaluation {
public static final ParseField NAME = new ParseField("outlier_detection", "binary_soft_classification");
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");
private static final ParseField METRICS = new ParseField("metrics");
public static final ConstructingObjectParser<OutlierDetection, Void> PARSER = new ConstructingObjectParser<>(
@ -50,30 +51,34 @@ public class OutlierDetection implements Evaluation {
return PARSER.apply(parser, null);
}
static QueryBuilder actualIsTrueQuery(String actualField) {
public static QueryBuilder actualIsTrueQuery(String actualField) {
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
}
/**
* The field where the actual class is marked up.
* The value of this field is assumed to either be 1 or 0, or true or false.
* The collection of fields in the index being evaluated.
* The value of fields.getActualField() is assumed to be either 1 or 0, or true or false.
* The value of fields.getPredictedProbabilityField() is assumed to be a number in [0.0, 1.0].
* Other fields are not needed by this evaluation.
*/
private final String actualField;
/**
* The field of the predicted probability in [0.0, 1.0].
*/
private final String predictedProbabilityField;
private final EvaluationFields fields;
/**
* The list of metrics to calculate
*/
private final List<EvaluationMetric> metrics;
public OutlierDetection(String actualField, String predictedProbabilityField,
public OutlierDetection(String actualField,
String predictedProbabilityField,
@Nullable List<EvaluationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
this.fields =
new EvaluationFields(
ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD),
null,
null,
null,
ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD),
false);
this.metrics = initMetrics(metrics, OutlierDetection::defaultMetrics);
}
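For orientation, a minimal construction sketch; the field names are invented for illustration and a null metrics list falls back to the defaults via initMetrics:
    // Illustration only -- "is_outlier" and "ml.outlier_score" are made-up field names.
    OutlierDetection evaluation = new OutlierDetection(
        "is_outlier",          // actual field, expected to hold 1/0 or true/false
        "ml.outlier_score",    // predicted probability field, expected to be in [0.0, 1.0]
        null);                 // null -> default metrics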
@ -86,8 +91,7 @@ public class OutlierDetection implements Evaluation {
}
public OutlierDetection(StreamInput in) throws IOException {
this.actualField = in.readString();
this.predictedProbabilityField = in.readString();
this.fields = new EvaluationFields(in.readString(), null, null, null, in.readString(), false);
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
}
@ -97,13 +101,8 @@ public class OutlierDetection implements Evaluation {
}
@Override
public String getActualField() {
return actualField;
}
@Override
public String getPredictedField() {
return predictedProbabilityField;
public EvaluationFields getFields() {
return fields;
}
@Override
@ -118,16 +117,16 @@ public class OutlierDetection implements Evaluation {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualField);
out.writeString(predictedProbabilityField);
out.writeString(fields.getActualField());
out.writeString(fields.getPredictedProbabilityField());
out.writeNamedWriteableList(metrics);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField());
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), fields.getPredictedProbabilityField());
builder.startObject(METRICS.getPreferredName());
for (EvaluationMetric metric : metrics) {
@ -144,13 +143,12 @@ public class OutlierDetection implements Evaluation {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OutlierDetection that = (OutlierDetection) o;
return Objects.equals(actualField, that.actualField)
&& Objects.equals(predictedProbabilityField, that.predictedProbabilityField)
return Objects.equals(fields, that.fields)
&& Objects.equals(metrics, that.metrics);
}
@Override
public int hashCode() {
return Objects.hash(actualField, predictedProbabilityField, metrics);
return Objects.hash(fields, metrics);
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -31,6 +33,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
@ -86,13 +89,19 @@ public class Huber implements EvaluationMetric {
return NAME.getPreferredName();
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
String actualField = fields.getActualField();
String predictedField = fields.getPredictedField();
return Tuple.tuple(
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, delta * delta)))),
Collections.emptyList());

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@ -19,6 +20,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -31,6 +33,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
@ -70,13 +73,19 @@ public class MeanSquaredError implements EvaluationMetric {
return NAME.getPreferredName();
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
String actualField = fields.getActualField();
String predictedField = fields.getPredictedField();
return Tuple.tuple(
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))),
Collections.emptyList());

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -31,6 +33,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
@ -85,13 +88,19 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
return NAME.getPreferredName();
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
String actualField = fields.getActualField();
String predictedField = fields.getPredictedField();
return Tuple.tuple(
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, offset)))),
Collections.emptyList());

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@ -20,6 +21,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -32,6 +34,7 @@ import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
@ -74,13 +77,19 @@ public class RSquared implements EvaluationMetric {
return NAME.getPreferredName();
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
String actualField = fields.getActualField();
String predictedField = fields.getPredictedField();
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),

View File

@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -21,6 +22,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.ACTUAL_FIELD;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields.PREDICTED_FIELD;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
@ -30,8 +33,6 @@ public class Regression implements Evaluation {
public static final ParseField NAME = new ParseField("regression");
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
private static final ParseField METRICS = new ParseField("metrics");
@SuppressWarnings("unchecked")
@ -50,16 +51,12 @@ public class Regression implements Evaluation {
}
/**
* The field containing the actual value
* The value of this field is assumed to be numeric
* The collection of fields in the index being evaluated.
* The value of fields.getActualField() is assumed to be numeric.
* The value of fields.getPredictedField() is assumed to be numeric.
* Other fields are not needed by this evaluation.
*/
private final String actualField;
/**
* The field containing the predicted value
* The value of this field is assumed to be numeric
*/
private final String predictedField;
private final EvaluationFields fields;
/**
* The list of metrics to calculate
@ -67,8 +64,14 @@ public class Regression implements Evaluation {
private final List<EvaluationMetric> metrics;
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.fields =
new EvaluationFields(
ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD),
ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD),
null,
null,
null,
false);
this.metrics = initMetrics(metrics, Regression::defaultMetrics);
}
@ -77,8 +80,7 @@ public class Regression implements Evaluation {
}
public Regression(StreamInput in) throws IOException {
this.actualField = in.readString();
this.predictedField = in.readString();
this.fields = new EvaluationFields(in.readString(), in.readString(), null, null, null, false);
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
}
@ -88,13 +90,8 @@ public class Regression implements Evaluation {
}
@Override
public String getActualField() {
return actualField;
}
@Override
public String getPredictedField() {
return predictedField;
public EvaluationFields getFields() {
return fields;
}
@Override
@ -109,16 +106,16 @@ public class Regression implements Evaluation {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualField);
out.writeString(predictedField);
out.writeString(fields.getActualField());
out.writeString(fields.getPredictedField());
out.writeNamedWriteableList(metrics);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
builder.field(ACTUAL_FIELD.getPreferredName(), fields.getActualField());
builder.field(PREDICTED_FIELD.getPreferredName(), fields.getPredictedField());
builder.startObject(METRICS.getPreferredName());
for (EvaluationMetric metric : metrics) {
@ -135,13 +132,12 @@ public class Regression implements Evaluation {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Regression that = (Regression) o;
return Objects.equals(that.actualField, this.actualField)
&& Objects.equals(that.predictedField, this.predictedField)
return Objects.equals(that.fields, this.fields)
&& Objects.equals(that.metrics, this.metrics);
}
@Override
public int hashCode() {
return Objects.hash(actualField, predictedField, metrics);
return Objects.hash(fields, metrics);
}
}

View File

@ -11,13 +11,14 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRocResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AccuracyResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import java.util.Arrays;
@ -25,6 +26,10 @@ import java.util.List;
public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializingTestCase<Response> {
private static final String OUTLIER_DETECTION = "outlier_detection";
private static final String CLASSIFICATION = "classification";
private static final String REGRESSION = "regression";
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
@ -32,18 +37,35 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
@Override
protected Response createTestInstance() {
String evaluationName = randomAlphaOfLength(10);
List<EvaluationMetricResult> metrics =
String evaluationName = randomFrom(OUTLIER_DETECTION, CLASSIFICATION, REGRESSION);
List<EvaluationMetricResult> metrics;
switch (evaluationName) {
case OUTLIER_DETECTION:
metrics = randomSubsetOf(
Arrays.asList(
AucRocResultTests.createRandom()));
break;
case CLASSIFICATION:
metrics = randomSubsetOf(
Arrays.asList(
AucRocResultTests.createRandom(),
AccuracyResultTests.createRandom(),
PrecisionResultTests.createRandom(),
RecallResultTests.createRandom(),
MulticlassConfusionMatrixResultTests.createRandom(),
MulticlassConfusionMatrixResultTests.createRandom()));
break;
case REGRESSION:
metrics = randomSubsetOf(
Arrays.asList(
new MeanSquaredError.Result(randomDouble()),
new MeanSquaredLogarithmicError.Result(randomDouble()),
new Huber.Result(randomDouble()),
new RSquared.Result(randomDouble()));
return new Response(evaluationName, randomSubsetOf(metrics));
new RSquared.Result(randomDouble())));
break;
default:
throw new AssertionError("Please add missing \"case\" variant to the \"switch\" statement");
}
return new Response(evaluationName, metrics);
}
@Override

View File

@ -44,7 +44,6 @@ import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
@ -366,15 +365,27 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
assertThat(
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
equalTo(Collections.singletonMap("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING)));
Map<String, Object> expectedTopClassesMapping = new HashMap<String, Object>() {{
put("type", "nested");
put("properties", new HashMap<String, Object>() {{
put("class_name", Collections.singletonMap("bar", "baz"));
put("class_probability", Collections.singletonMap("type", "double"));
}});
}};
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
"results");
assertThat(explicitlyMappedFields,
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")));
assertThat(explicitlyMappedFields, hasEntry("results.top_classes", expectedTopClassesMapping));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING));
expectedTopClassesMapping = new HashMap<String, Object>() {{
put("type", "nested");
put("properties", new HashMap<String, Object>() {{
put("class_name", Collections.singletonMap("type", "long"));
put("class_probability", Collections.singletonMap("type", "double"));
}});
}};
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
new HashMap<String, Object>() {{
put("foo", new HashMap<String, String>() {{
@ -384,10 +395,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
put("bar", Collections.singletonMap("type", "long"));
}},
"results");
assertThat(explicitlyMappedFields,
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")));
assertThat(explicitlyMappedFields, hasEntry("results.top_classes", expectedTopClassesMapping));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", Classification.FEATURE_IMPORTANCE_MAPPING));
assertThat(

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.test.ESTestCase;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
public class EvaluationFieldsTests extends ESTestCase {
public void testConstructorAndGetters() {
EvaluationFields fields = new EvaluationFields("a", "b", "c", "d", "e", true);
assertThat(fields.getActualField(), is(equalTo("a")));
assertThat(fields.getPredictedField(), is(equalTo("b")));
assertThat(fields.getTopClassesField(), is(equalTo("c")));
assertThat(fields.getPredictedClassField(), is(equalTo("d")));
assertThat(fields.getPredictedProbabilityField(), is(equalTo("e")));
assertThat(fields.isPredictedProbabilityFieldNested(), is(true));
}
public void testConstructorAndGetters_WithNullValues() {
EvaluationFields fields = new EvaluationFields("a", null, "c", null, "e", true);
assertThat(fields.getActualField(), is(equalTo("a")));
assertThat(fields.getPredictedField(), is(nullValue()));
assertThat(fields.getTopClassesField(), is(equalTo("c")));
assertThat(fields.getPredictedClassField(), is(nullValue()));
assertThat(fields.getPredictedProbabilityField(), is(equalTo("e")));
assertThat(fields.isPredictedProbabilityFieldNested(), is(true));
}
public void testListPotentiallyRequiredFields() {
EvaluationFields fields = new EvaluationFields("a", "b", "c", "d", "e", randomBoolean());
assertThat(fields.listPotentiallyRequiredFields().stream().map(Tuple::v2).collect(toList()), contains("a", "b", "c", "d", "e"));
}
public void testListPotentiallyRequiredFields_WithNullValues() {
EvaluationFields fields = new EvaluationFields("a", null, "c", null, "e", randomBoolean());
assertThat(fields.listPotentiallyRequiredFields().stream().map(Tuple::v2).collect(toList()), contains("a", null, "c", null, "e"));
}
}
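The tuples returned by listPotentiallyRequiredFields pair each field's parse name with its configured value, which is what lets an evaluation cross-check them against the metrics' getRequiredFields sets. A rough sketch of such a check (this is not the actual Evaluation code, only an illustration; the message format mirrors the one asserted in ClassificationTests below):
    // Sketch only: fail if a metric requires a field that was not configured.
    static void checkRequiredFields(String evaluationName, EvaluationFields fields, List<EvaluationMetric> metrics) {
        for (Tuple<String, String> field : fields.listPotentiallyRequiredFields()) {
            if (field.v2() != null) {
                continue; // the field is configured, nothing to check
            }
            List<String> offending = metrics.stream()
                .filter(m -> m.getRequiredFields().contains(field.v1()))
                .map(EvaluationMetric::getName)
                .collect(toList());
            if (offending.isEmpty() == false) {
                throw ExceptionsHelper.badRequestException(
                    "[{}] must define [{}] as required by the following metrics {}",
                    evaluationName, field.v1(), offending);
            }
        }
    }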

View File

@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.test.ESTestCase;
import java.util.Arrays;
import java.util.List;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class AbstractAucRocTests extends ESTestCase {
public void testCalculateAucScore_GivenZeroPercentiles() {
double[] tpPercentiles = zeroPercentiles();
double[] fpPercentiles = zeroPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
assertThat(aucRocScore, closeTo(0.5, 0.01));
}
public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() {
double[] tpPercentiles = randomPercentiles();
double[] fpPercentiles = zeroPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
assertThat(aucRocScore, closeTo(1.0, 0.1));
}
public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() {
double[] tpPercentiles = zeroPercentiles();
double[] fpPercentiles = randomPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
assertThat(aucRocScore, closeTo(0.0, 0.1));
}
public void testCalculateAucScore_GivenRandomPercentiles() {
for (int i = 0; i < 20; i++) {
double[] tpPercentiles = randomPercentiles();
double[] fpPercentiles = randomPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);
assertThat(aucRocScore, greaterThanOrEqualTo(0.0));
assertThat(aucRocScore, lessThanOrEqualTo(1.0));
assertThat(inverseAucRocScore, greaterThanOrEqualTo(0.0));
assertThat(inverseAucRocScore, lessThanOrEqualTo(1.0));
assertThat(aucRocScore + inverseAucRocScore, closeTo(1.0, 0.05));
}
}
public void testCalculateAucScore_GivenPrecalculated() {
double[] tpPercentiles = new double[99];
double[] fpPercentiles = new double[99];
double[] tpSimplified = new double[] { 0.3, 0.6, 0.5, 0.8 };
double[] fpSimplified = new double[] { 0.1, 0.3, 0.5, 0.5 };
for (int i = 0; i < tpPercentiles.length; i++) {
int simplifiedIndex = i / 25;
tpPercentiles[i] = tpSimplified[simplifiedIndex];
fpPercentiles[i] = fpSimplified[simplifiedIndex];
}
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);
assertThat(aucRocScore, closeTo(0.8, 0.05));
assertThat(inverseAucRocScore, closeTo(0.2, 0.05));
}
public static double[] zeroPercentiles() {
double[] percentiles = new double[99];
Arrays.fill(percentiles, 0.0);
return percentiles;
}
public static double[] randomPercentiles() {
double[] percentiles = new double[99];
for (int i = 0; i < percentiles.length; i++) {
percentiles[i] = randomDouble();
}
Arrays.sort(percentiles);
return percentiles;
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
@ -32,6 +33,7 @@ import static org.hamcrest.Matchers.equalTo;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);
@Override
protected Accuracy doParseInstance(XContentParser parser) throws IOException {
@ -88,7 +90,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
Accuracy accuracy = new Accuracy();
accuracy.process(aggs);
assertThat(accuracy.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(accuracy.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
Result result = accuracy.getResult().get();
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
@ -130,7 +132,7 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
Accuracy accuracy = new Accuracy();
accuracy.aggs(EVALUATION_PARAMETERS, "foo", "bar");
accuracy.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}

View File

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.AucRocPoint;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AbstractAucRoc.Result;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class AucRocResultTests extends AbstractWireSerializingTestCase<Result> {
public static Result createRandom() {
double score = randomDoubleBetween(0.0, 1.0, true);
Long docCount = randomBoolean() ? randomLong() : null;
List<AucRocPoint> curve =
Stream
.generate(() -> new AucRocPoint(randomDouble(), randomDouble(), randomDouble()))
.limit(randomIntBetween(0, 20))
.collect(Collectors.toList());
return new Result(score, docCount, curve);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
protected Result createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Result> instanceReader() {
return Result::new;
}
}

View File

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
public class AucRocTests extends AbstractSerializingTestCase<AucRoc> {
@Override
protected AucRoc doParseInstance(XContentParser parser) throws IOException {
return AucRoc.PARSER.apply(parser, null);
}
@Override
protected AucRoc createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<AucRoc> instanceReader() {
return AucRoc::new;
}
public static AucRoc createRandom() {
return new AucRoc(randomBoolean() ? randomBoolean() : null, randomAlphaOfLength(randomIntBetween(2, 10)));
}
}
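Combined with the Classification constructor's new top_classes argument (exercised in ClassificationTests below), a typical configuration could look like the following sketch; the field and class names are invented for illustration:
    // Illustration only -- "animal_class", "ml.animal_class_prediction", "ml.top_classes" and "cat"
    // are made-up names.
    Classification evaluation = new Classification(
        "animal_class",                   // actual field
        "ml.animal_class_prediction",     // predicted field
        "ml.top_classes",                 // top classes field, required when AucRoc is requested
        Arrays.asList(new AucRoc(true, "cat")));  // include the curve, treat "cat" as the positive class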

View File

@ -6,12 +6,14 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@ -23,6 +25,7 @@ import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -33,6 +36,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
@ -61,10 +65,17 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
randomSubsetOf(
Arrays.asList(
AccuracyTests.createRandom(),
AucRocTests.createRandom(),
PrecisionTests.createRandom(),
RecallTests.createRandom(),
MulticlassConfusionMatrixTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
boolean usesAucRoc = metrics.stream().map(EvaluationMetric::getName).anyMatch(n -> AucRoc.NAME.getPreferredName().equals(n));
return new Classification(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
// If AucRoc is to be calculated, the top_classes field is required
(usesAucRoc || randomBoolean()) ? randomAlphaOfLength(10) : null,
metrics.isEmpty() ? null : metrics);
}
@Override
@ -82,13 +93,35 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
return Classification::new;
}
public void testConstructor_GivenMissingField() {
FakeClassificationMetric metric = new FakeClassificationMetric("fake");
ElasticsearchStatusException e =
expectThrows(
ElasticsearchStatusException.class,
() -> new Classification("foo", null, null, Collections.singletonList(metric)));
assertThat(
e.getMessage(),
is(equalTo("[classification] must define [predicted_field] as required by the following metrics [fake]")));
}
public void testConstructor_GivenEmptyMetrics() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", "bar", Collections.emptyList()));
() -> new Classification("foo", "bar", "results", Collections.emptyList()));
assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics"));
}
public void testBuildSearch() {
public void testGetFields() {
Classification evaluation = new Classification("foo", "bar", "results", null);
EvaluationFields fields = evaluation.getFields();
assertThat(fields.getActualField(), is(equalTo("foo")));
assertThat(fields.getPredictedField(), is(equalTo("bar")));
assertThat(fields.getTopClassesField(), is(equalTo("results")));
assertThat(fields.getPredictedClassField(), is(equalTo("results.class_name")));
assertThat(fields.getPredictedProbabilityField(), is(equalTo("results.class_probability")));
assertThat(fields.isPredictedProbabilityFieldNested(), is(true));
}
public void testBuildSearch_WithDefaultNonRequiredNestedFields() {
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
@ -101,7 +134,78 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
Classification evaluation = new Classification("act", "pred", null, Arrays.asList(new MulticlassConfusionMatrix()));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
}
public void testBuildSearch_WithExplicitNonRequiredNestedFields() {
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value"));
QueryBuilder expectedSearchQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery("act"))
.filter(QueryBuilders.existsQuery("pred"))
.filter(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
Classification evaluation = new Classification("act", "pred", "results", Arrays.asList(new MulticlassConfusionMatrix()));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
}
public void testBuildSearch_WithDefaultRequiredNestedFields() {
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value"));
QueryBuilder expectedSearchQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery("act"))
.filter(
QueryBuilders.nestedQuery("ml.top_classes", QueryBuilders.existsQuery("ml.top_classes.class_name"), ScoreMode.None)
.ignoreUnmapped(true))
.filter(
QueryBuilders.nestedQuery(
"ml.top_classes", QueryBuilders.existsQuery("ml.top_classes.class_probability"), ScoreMode.None)
.ignoreUnmapped(true))
.filter(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
Classification evaluation = new Classification("act", "pred", null, Arrays.asList(new AucRoc(false, "some-value")));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
}
public void testBuildSearch_WithExplicitRequiredNestedFields() {
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value"));
QueryBuilder expectedSearchQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery("act"))
.filter(
QueryBuilders.nestedQuery("results", QueryBuilders.existsQuery("results.class_name"), ScoreMode.None)
.ignoreUnmapped(true))
.filter(
QueryBuilders.nestedQuery("results", QueryBuilders.existsQuery("results.class_probability"), ScoreMode.None)
.ignoreUnmapped(true))
.filter(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
Classification evaluation = new Classification("act", "pred", "results", Arrays.asList(new AucRoc(false, "some-value")));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(EVALUATION_PARAMETERS, userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
@ -114,7 +218,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
EvaluationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
EvaluationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
Classification evaluation = new Classification("act", "pred", null, Arrays.asList(metric1, metric2, metric3, metric4));
assertThat(metric1.getResult(), isEmpty());
assertThat(metric2.getResult(), isEmpty());
assertThat(metric3.getResult(), isEmpty());
@ -183,6 +287,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
private int currentStepIndex;
private EvaluationMetricResult result;
FakeClassificationMetric(String name) {
this(name, 1);
}
FakeClassificationMetric(String name, int numSteps) {
this.name = name;
this.numSteps = numSteps;
@ -198,10 +306,14 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
return name;
}
@Override
public Set<String> getRequiredFields() {
return Sets.newHashSet(EvaluationFields.ACTUAL_FIELD.getPreferredName(), EvaluationFields.PREDICTED_FIELD.getPreferredName());
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
EvaluationFields fields) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}

View File

@ -13,6 +13,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
@ -37,6 +38,7 @@ import static org.hamcrest.Matchers.not;
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);
@Override
protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
@ -83,7 +85,8 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public void testAggs() {
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred");
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs =
confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
assertThat(aggs, isTuple(not(empty()), empty()));
assertThat(confusionMatrix.getResult(), isEmpty());
}
@ -119,7 +122,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(
@ -162,7 +165,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(
@ -246,7 +249,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
confusionMatrix.process(aggsStep2);
confusionMatrix.process(aggsStep3);
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(confusionMatrix.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException;
@ -28,6 +29,7 @@ import static org.hamcrest.Matchers.equalTo;
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);
@Override
protected Precision doParseInstance(XContentParser parser) throws IOException {
@ -64,7 +66,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
Precision precision = new Precision();
precision.process(aggs);
assertThat(precision.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(precision.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
}
@ -114,7 +116,7 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
Aggregations aggs =
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
Precision precision = new Precision();
precision.aggs(EVALUATION_PARAMETERS, "foo", "bar");
precision.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException;
@ -27,6 +28,7 @@ import static org.hamcrest.Matchers.equalTo;
public class RecallTests extends AbstractSerializingTestCase<Recall> {
private static final EvaluationParameters EVALUATION_PARAMETERS = new EvaluationParameters(100);
private static final EvaluationFields EVALUATION_FIELDS = new EvaluationFields("foo", "bar", null, null, null, true);
@Override
protected Recall doParseInstance(XContentParser parser) throws IOException {
@ -62,7 +64,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
Recall recall = new Recall();
recall.process(aggs);
assertThat(recall.aggs(EVALUATION_PARAMETERS, "act", "pred"), isTuple(empty(), empty()));
assertThat(recall.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS), isTuple(empty(), empty()));
assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
}
@ -113,7 +115,7 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
Recall recall = new Recall();
recall.aggs(EVALUATION_PARAMETERS, "foo", "bar");
recall.aggs(EVALUATION_PARAMETERS, EVALUATION_FIELDS);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}

View File

@ -10,12 +10,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class AucRocTests extends AbstractSerializingTestCase<AucRoc> {
@ -37,91 +31,4 @@ public class AucRocTests extends AbstractSerializingTestCase<AucRoc> {
public static AucRoc createRandom() {
return new AucRoc(randomBoolean() ? randomBoolean() : null);
}
public void testCalculateAucScore_GivenZeroPercentiles() {
double[] tpPercentiles = zeroPercentiles();
double[] fpPercentiles = zeroPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
assertThat(aucRocScore, closeTo(0.5, 0.01));
}
public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() {
double[] tpPercentiles = randomPercentiles();
double[] fpPercentiles = zeroPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
assertThat(aucRocScore, closeTo(1.0, 0.1));
}
public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() {
double[] tpPercentiles = zeroPercentiles();
double[] fpPercentiles = randomPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
assertThat(aucRocScore, closeTo(0.0, 0.1));
}
public void testCalculateAucScore_GivenRandomPercentiles() {
for (int i = 0; i < 20; i++) {
double[] tpPercentiles = randomPercentiles();
double[] fpPercentiles = randomPercentiles();
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);
assertThat(aucRocScore, greaterThanOrEqualTo(0.0));
assertThat(aucRocScore, lessThanOrEqualTo(1.0));
assertThat(inverseAucRocScore, greaterThanOrEqualTo(0.0));
assertThat(inverseAucRocScore, lessThanOrEqualTo(1.0));
assertThat(aucRocScore + inverseAucRocScore, closeTo(1.0, 0.05));
}
}
public void testCalculateAucScore_GivenPrecalculated() {
double[] tpPercentiles = new double[99];
double[] fpPercentiles = new double[99];
double[] tpSimplified = new double[] { 0.3, 0.6, 0.5 , 0.8 };
double[] fpSimplified = new double[] { 0.1, 0.3, 0.5 , 0.5 };
for (int i = 0; i < tpPercentiles.length; i++) {
int simplifiedIndex = i / 25;
tpPercentiles[i] = tpSimplified[simplifiedIndex];
fpPercentiles[i] = fpSimplified[simplifiedIndex];
}
List<AucRoc.AucRocPoint> curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles);
double aucRocScore = AucRoc.calculateAucScore(curve);
List<AucRoc.AucRocPoint> inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles);
double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve);
assertThat(aucRocScore, closeTo(0.8, 0.05));
assertThat(inverseAucRocScore, closeTo(0.2, 0.05));
}
public static double[] zeroPercentiles() {
double[] percentiles = new double[99];
Arrays.fill(percentiles, 0.0);
return percentiles;
}
public static double[] randomPercentiles() {
double[] percentiles = new double[99];
for (int i = 0; i < percentiles.length; i++) {
percentiles[i] = randomDouble();
}
Arrays.sort(percentiles);
return percentiles;
}
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
@ -26,6 +27,8 @@ import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
@ -86,6 +89,17 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
assertThat(e.getMessage(), equalTo("[outlier_detection] must have one or more metrics"));
}
public void testGetFields() {
OutlierDetection evaluation = new OutlierDetection("foo", "bar", null);
EvaluationFields fields = evaluation.getFields();
assertThat(fields.getActualField(), is(equalTo("foo")));
assertThat(fields.getPredictedField(), is(nullValue()));
assertThat(fields.getTopClassesField(), is(nullValue()));
assertThat(fields.getPredictedClassField(), is(nullValue()));
assertThat(fields.getPredictedProbabilityField(), is(equalTo("bar")));
assertThat(fields.isPredictedProbabilityFieldNested(), is(false));
}
public void testBuildSearch() {
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()

View File

@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
@ -26,6 +27,8 @@ import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
@ -73,6 +76,17 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
}
public void testGetFields() {
Regression evaluation = new Regression("foo", "bar", null);
EvaluationFields fields = evaluation.getFields();
assertThat(fields.getActualField(), is(equalTo("foo")));
assertThat(fields.getPredictedField(), is(equalTo("bar")));
assertThat(fields.getTopClassesField(), is(nullValue()));
assertThat(fields.getPredictedClassField(), is(nullValue()));
assertThat(fields.getPredictedProbabilityField(), is(nullValue()));
assertThat(fields.isPredictedProbabilityFieldNested(), is(false));
}
public void testBuildSearch() {
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()

View File

@ -115,6 +115,11 @@ yamlRestTest {
'ml/evaluate_data_frame/Test classification given evaluation with empty metrics',
'ml/evaluate_data_frame/Test classification given missing actual_field',
'ml/evaluate_data_frame/Test classification given missing predicted_field',
'ml/evaluate_data_frame/Test classification given missing top_classes_field',
'ml/evaluate_data_frame/Test classification auc_roc given actual_field is never equal to fish',
'ml/evaluate_data_frame/Test classification auc_roc given predicted_class_field is never equal to mouse',
'ml/evaluate_data_frame/Test classification auc_roc with missing class_name',
'ml/evaluate_data_frame/Test classification accuracy with missing predicted_field',
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
'ml/evaluate_data_frame/Test regression given missing actual_field',
'ml/evaluate_data_frame/Test regression given missing predicted_field',

View File

@ -16,6 +16,7 @@ import org.elasticsearch.search.aggregations.MultiBucketConsumerService.TooManyB
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
@ -24,13 +25,17 @@ import org.junit.After;
import org.junit.Before;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.IntStream;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
@ -38,16 +43,19 @@ import static org.hamcrest.Matchers.notANumber;
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
private static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
private static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction";
private static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
private static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
private static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction";
private static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
private static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
private static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction";
static final String ANIMAL_NAME_KEYWORD_FIELD = "animal_name_keyword";
static final String ANIMAL_NAME_PREDICTION_KEYWORD_FIELD = "animal_name_keyword_prediction";
static final String ANIMAL_NAME_PREDICTION_PROB_FIELD = "animal_name_prediction_prob";
static final String NO_LEGS_KEYWORD_FIELD = "no_legs_keyword";
static final String NO_LEGS_INTEGER_FIELD = "no_legs_integer";
static final String NO_LEGS_PREDICTION_INTEGER_FIELD = "no_legs_integer_prediction";
static final String IS_PREDATOR_KEYWORD_FIELD = "predator_keyword";
static final String IS_PREDATOR_BOOLEAN_FIELD = "predator_boolean";
static final String IS_PREDATOR_PREDICTION_BOOLEAN_FIELD = "predator_boolean_prediction";
static final String IS_PREDATOR_PREDICTION_PROBABILITY_FIELD = "predator_prediction_probability";
static final String ML_TOP_CLASSES_FIELD = "ml_results";
@Before
public void setup() {
@ -67,7 +75,8 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
public void testEvaluate_DefaultMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null));
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(
@ -82,6 +91,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
new Classification(
ANIMAL_NAME_KEYWORD_FIELD,
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
null,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
@ -116,6 +126,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
new Classification(
actualField,
predictedField,
null,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
@ -139,9 +150,37 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(recallResult.getAvgRecall(), equalTo(0.0));
}
private AucRoc.Result evaluateAucRoc(boolean includeCurve) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, null, ML_TOP_CLASSES_FIELD, Arrays.asList(new AucRoc(includeCurve, "cat"))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
AucRoc.Result aucrocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(aucrocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
return aucrocResult;
}
public void testEvaluate_AucRoc_DoNotIncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(false);
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
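// 75 = 5 actual classes x (1 + 2 + 3 + 4 + 5) documents per class indexed by indexAnimalsData below;
// every one of those documents lists "cat" in its top_classes, so all of them count towards doc_count.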
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
assertThat(aucrocResult.getCurve(), hasSize(0));
}
public void testEvaluate_AucRoc_IncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(true);
assertThat(aucrocResult.getScore(), is(closeTo(0.5, 0.0001)));
assertThat(aucrocResult.getDocCount(), is(equalTo(75L)));
assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
}
private Accuracy.Result evaluateAccuracy(String actualField, String predictedField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Accuracy())));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -260,7 +299,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
private Precision.Result evaluatePrecision(String actualField, String predictedField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Precision())));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Precision())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -354,13 +393,13 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Precision()))));
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, Arrays.asList(new Precision()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
}
private Recall.Result evaluateRecall(String actualField, String predictedField) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, Arrays.asList(new Recall())));
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(actualField, predictedField, null, Arrays.asList(new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -469,7 +508,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new Recall()))));
new Classification(ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, null, Arrays.asList(new Recall()))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name_keyword] is too high"));
}
@ -478,7 +517,10 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(
ANIMAL_NAME_KEYWORD_FIELD, ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
ANIMAL_NAME_KEYWORD_FIELD,
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
null,
Arrays.asList(new MulticlassConfusionMatrix())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
@ -561,6 +603,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
new Classification(
ANIMAL_NAME_KEYWORD_FIELD,
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD,
null,
Arrays.asList(new MulticlassConfusionMatrix(3, null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
@ -595,7 +638,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
}
private static void createAnimalsIndex(String indexName) {
static void createAnimalsIndex(String indexName) {
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc",
ANIMAL_NAME_KEYWORD_FIELD, "type=keyword",
@ -605,28 +648,41 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
NO_LEGS_PREDICTION_INTEGER_FIELD, "type=integer",
IS_PREDATOR_KEYWORD_FIELD, "type=keyword",
IS_PREDATOR_BOOLEAN_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean")
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, "type=boolean",
IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, "type=double",
ML_TOP_CLASSES_FIELD, "type=nested")
.get();
}
private static void indexAnimalsData(String indexName) {
static void indexAnimalsData(String indexName) {
List<String> animalNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < animalNames.size(); i++) {
for (int j = 0; j < animalNames.size(); j++) {
for (int k = 0; k < j + 1; k++) {
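// Every document gets the same synthetic top_classes list (class probabilities 0.4, 0.3, 0.2, 0.1, 0.0),
// so the scores carry no signal for any class; this is why the AUC ROC tests above expect a score of ~0.5.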
List<?> topClasses =
IntStream
.range(0, 5)
.mapToObj(ix -> new HashMap<String, Object>() {{
put("class_name", animalNames.get(ix));
put("class_probability", 0.4 - 0.1 * ix);
}})
.collect(toList());
bulkRequestBuilder.add(
new IndexRequest(indexName)
.source(
ANIMAL_NAME_KEYWORD_FIELD, animalNames.get(i),
ANIMAL_NAME_PREDICTION_KEYWORD_FIELD, animalNames.get((i + j) % animalNames.size()),
ANIMAL_NAME_PREDICTION_PROB_FIELD, animalNames.get((i + j) % animalNames.size()),
NO_LEGS_KEYWORD_FIELD, String.valueOf(i + 1),
NO_LEGS_INTEGER_FIELD, i + 1,
NO_LEGS_PREDICTION_INTEGER_FIELD, j + 1,
IS_PREDATOR_KEYWORD_FIELD, String.valueOf(i % 2 == 0),
IS_PREDATOR_BOOLEAN_FIELD, i % 2 == 0,
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0));
IS_PREDATOR_PREDICTION_BOOLEAN_FIELD, (i + j) % 2 == 0,
IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, i % 2 == 0 ? 1.0 - 0.1 * i : 0.1 * i,
ML_TOP_CLASSES_FIELD, topClasses));
}
}
}

View File

@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
@ -957,9 +958,15 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
dependentVariable,
predictedClassField,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
null,
Arrays.asList(
new Accuracy(),
new AucRoc(true, dependentVariableValues.get(0).toString()),
new MulticlassConfusionMatrix(),
new Precision(),
new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(5));
{ // Accuracy
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
@ -970,9 +977,17 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}
}
{ // AucRoc
AucRoc.Result aucRocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(1);
assertThat(aucRocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
assertThat(aucRocResult.getScore(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
assertThat(aucRocResult.getDocCount(), allOf(greaterThanOrEqualTo(1L), lessThanOrEqualTo(350L)));
assertThat(aucRocResult.getCurve(), hasSize(greaterThan(0)));
}
{ // MulticlassConfusionMatrix
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(1);
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(2);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
List<MulticlassConfusionMatrix.ActualClass> actualClasses = confusionMatrixResult.getConfusionMatrix();
assertThat(
@ -990,7 +1005,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}
{ // Precision
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2);
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(3);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
for (Precision.PerClassResult klass : precisionResult.getClasses()) {
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
@ -999,7 +1014,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}
{ // Recall
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3);
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(4);
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
for (Recall.PerClassResult klass : recallResult.getClasses()) {
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));

View File

@ -0,0 +1,114 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.ConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.Recall;
import org.junit.After;
import org.junit.Before;
import java.util.Arrays;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.ANIMALS_DATA_INDEX;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.IS_PREDATOR_BOOLEAN_FIELD;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.IS_PREDATOR_PREDICTION_PROBABILITY_FIELD;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.createAnimalsIndex;
import static org.elasticsearch.xpack.ml.integration.ClassificationEvaluationIT.indexAnimalsData;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
public class OutlierDetectionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@Before
public void setup() {
createAnimalsIndex(ANIMALS_DATA_INDEX);
indexAnimalsData(ANIMALS_DATA_INDEX);
}
@After
public void cleanup() {
cleanUp();
}
public void testEvaluate_DefaultMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new OutlierDetection(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName()));
assertThat(
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
containsInAnyOrder(
AucRoc.NAME.getPreferredName(),
Precision.NAME.getPreferredName(),
Recall.NAME.getPreferredName(),
ConfusionMatrix.NAME.getPreferredName()));
}
public void testEvaluate_AllMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new OutlierDetection(
IS_PREDATOR_BOOLEAN_FIELD,
IS_PREDATOR_PREDICTION_PROBABILITY_FIELD,
Arrays.asList(
new AucRoc(false),
new Precision(Arrays.asList(0.5)),
new Recall(Arrays.asList(0.5)),
new ConfusionMatrix(Arrays.asList(0.5)))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName()));
assertThat(
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
containsInAnyOrder(
AucRoc.NAME.getPreferredName(),
Precision.NAME.getPreferredName(),
Recall.NAME.getPreferredName(),
ConfusionMatrix.NAME.getPreferredName()));
}
private AucRoc.Result evaluateAucRoc(String actualField, String predictedField, boolean includeCurve) {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new OutlierDetection(actualField, predictedField, Arrays.asList(new AucRoc(includeCurve))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
AucRoc.Result aucrocResult = (AucRoc.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(aucrocResult.getMetricName(), equalTo(AucRoc.NAME.getPreferredName()));
return aucrocResult;
}
public void testEvaluate_AucRoc_DoNotIncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, false);
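// The indexed data separates the classes perfectly: predators get prediction probabilities of 0.6-1.0
// while non-predators get 0.1-0.3, hence the expected AUC ROC score of 1.0.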
assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001)));
assertThat(aucrocResult.getCurve(), hasSize(0));
}
public void testEvaluate_AucRoc_IncludeCurve() {
AucRoc.Result aucrocResult = evaluateAucRoc(IS_PREDATOR_BOOLEAN_FIELD, IS_PREDATOR_PREDICTION_PROBABILITY_FIELD, true);
assertThat(aucrocResult.getScore(), is(closeTo(1.0, 0.0001)));
assertThat(aucrocResult.getCurve(), hasSize(greaterThan(0)));
}
@Override
boolean supportsInference() {
return false;
}
}

View File

@ -206,25 +206,25 @@ public class DestinationIndexTests extends ESTestCase {
public void testCreateDestinationIndex_Classification() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(NUMERICAL_FIELD));
assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testCreateDestinationIndex_Classification_DependentVariableIsNested() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD));
assertThat(extractValue("_doc.properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testCreateDestinationIndex_Classification_DependentVariableIsAlias() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD));
assertThat(extractValue("_doc.properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testCreateDestinationIndex_Classification_DependentVariableIsAliasToNested() throws IOException {
Map<String, Object> map = testCreateDestinationIndex(new Classification(ALIAS_TO_NESTED_FIELD));
assertThat(extractValue("_doc.properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("_doc.properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testCreateDestinationIndex_ResultsFieldsExistsInSourceIndex() throws IOException {
@ -322,25 +322,25 @@ public class DestinationIndexTests extends ESTestCase {
public void testUpdateMappingsToDestIndex_Classification() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(NUMERICAL_FIELD));
assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsNested() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(OUTER_FIELD + "." + INNER_FIELD));
assertThat(extractValue("properties.ml.outer-field.inner-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAlias() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NUMERICAL_FIELD));
assertThat(extractValue("properties.ml.alias-to-numerical-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testUpdateMappingsToDestIndex_Classification_DependentVariableIsAliasToNested() throws IOException {
Map<String, Object> map = testUpdateMappingsToDestIndex(new Classification(ALIAS_TO_NESTED_FIELD));
assertThat(extractValue("properties.ml.alias-to-nested-field_prediction.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer"));
assertThat(extractValue("properties.ml.top_classes.properties.class_name.type", map), equalTo("integer"));
}
public void testUpdateMappingsToDestIndex_ResultsFieldsExistsInSourceIndex() throws IOException {

View File

@ -1,5 +1,14 @@
setup:
- do:
indices.create:
index: utopia
body:
mappings:
properties:
ml.top_classes:
type: nested
- do:
index:
index: utopia
@ -14,7 +23,11 @@ setup:
"classification_field_act": "dog",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "dog", "class_probability": 0.9},
{"class_name": "cat", "class_probability": 0.1}
]
}
- do:
@ -31,7 +44,11 @@ setup:
"classification_field_act": "cat",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.8},
{"class_name": "dog", "class_probability": 0.2}
]
}
- do:
@ -48,7 +65,11 @@ setup:
"classification_field_act": "mouse",
"classification_field_pred": "mouse",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.3},
{"class_name": "dog", "class_probability": 0.1}
]
}
- do:
@ -65,7 +86,11 @@ setup:
"classification_field_act": "dog",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.6},
{"class_name": "dog", "class_probability": 0.3}
]
}
- do:
@ -82,7 +107,11 @@ setup:
"classification_field_act": "cat",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "dog", "class_probability": 0.7},
{"class_name": "cat", "class_probability": 0.3}
]
}
- do:
@ -99,7 +128,11 @@ setup:
"classification_field_act": "dog",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "dog", "class_probability": 0.9},
{"class_name": "cat", "class_probability": 0.1}
]
}
- do:
@ -116,7 +149,11 @@ setup:
"classification_field_act": "cat",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.8},
{"class_name": "dog", "class_probability": 0.2}
]
}
- do:
@ -133,7 +170,11 @@ setup:
"classification_field_act": "mouse",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
"all_false_field": false,
"ml.top_classes": [
{"class_name": "cat", "class_probability": 0.8},
{"class_name": "dog", "class_probability": 0.2}
]
}
# This document is missing the required fields and should be ignored
@ -166,6 +207,7 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
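  # Only the 8 setup documents that contain all the required fields contribute to doc_count;
  # the setup document that lacks them is excluded.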
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve
---
@ -186,6 +228,7 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_false: outlier_detection.auc_roc.curve
---
@ -206,12 +249,13 @@ setup:
}
}
- match: { outlier_detection.auc_roc.score: 0.9899 }
- match: { outlier_detection.auc_roc.doc_count: 8 }
- is_true: outlier_detection.auc_roc.curve
---
"Test outlier_detection auc_roc given actual_field is always true":
- do:
catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/
catch: /\[auc_roc\] requires at least one \[all_true_field\] to have a different value than \[true\]/
ml.evaluate_data_frame:
body: >
{
@ -230,7 +274,7 @@ setup:
---
"Test outlier_detection auc_roc given actual_field is always false":
- do:
catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/
catch: /\[auc_roc\] requires at least one \[all_false_field\] to have the value \[true\]/
ml.evaluate_data_frame:
body: >
{
@ -371,6 +415,7 @@ setup:
}
}
- is_true: outlier_detection.auc_roc.score
- is_true: outlier_detection.auc_roc.doc_count
- is_true: outlier_detection.precision.0\.25
- is_true: outlier_detection.precision.0\.5
- is_true: outlier_detection.precision.0\.75
@ -443,7 +488,7 @@ setup:
---
"Test outlier_detection given missing actual_field":
- do:
catch: /No documents found containing both \[missing, outlier_score\] fields/
catch: /No documents found containing all the required fields \[missing, outlier_score\]/
ml.evaluate_data_frame:
body: >
{
@ -459,7 +504,7 @@ setup:
---
"Test outlier_detection given missing predicted_probability_field":
- do:
catch: /No documents found containing both \[is_outlier, missing\] fields/
catch: /No documents found containing all the required fields \[is_outlier, missing\]/
ml.evaluate_data_frame:
body: >
{
@ -598,7 +643,124 @@ setup:
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { }
"metrics": {}
}
}
}
---
"Test classification auc_roc with missing class_name":
- do:
# TODO: Revisit this error message as it does not give any clue about which field is missing
catch: /Failed to build \[auc_roc\] after last required field arrived/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {}
}
}
}
}
---
"Test classification auc_roc given actual_field is never equal to fish":
- do:
catch: /\[auc_roc\] requires at least one \[classification_field_act.keyword\] to have the value \[fish\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {
"class_name": "fish"
}
}
}
}
}
---
"Test classification auc_roc given predicted_class_field is never equal to mouse":
- do:
catch: /\[auc_roc\] requires at least one \[ml.top_classes.class_name\] to have the value \[mouse\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {
"class_name": "mouse"
}
}
}
}
}
---
"Test classification auc_roc":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"top_classes_field": "ml.top_classes",
"metrics": {
"auc_roc": {
"class_name": "cat"
}
}
}
}
}
- match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve
---
"Test classification auc_roc with default top_classes_field":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"metrics": {
"auc_roc": {
"class_name": "cat"
}
}
}
}
}
- match: { classification.auc_roc.score: 0.8050111095212122 }
- match: { classification.auc_roc.doc_count: 8 }
- is_false: classification.auc_roc.curve
---
"Test classification accuracy with missing predicted_field":
- do:
catch: /\[classification\] must define \[predicted_field\] as required by the following metrics \[accuracy\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"metrics": { "accuracy": {} }
}
}
}
@ -785,7 +947,7 @@ setup:
---
"Test classification given missing actual_field":
- do:
catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/
catch: /No documents found containing all the required fields \[missing, classification_field_pred.keyword\]/
ml.evaluate_data_frame:
body: >
{
@ -801,7 +963,7 @@ setup:
---
"Test classification given missing predicted_field":
- do:
catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/
catch: /No documents found containing all the required fields \[classification_field_act.keyword, missing\]/
ml.evaluate_data_frame:
body: >
{
@ -815,6 +977,27 @@ setup:
}
---
"Test classification given missing top_classes_field":
- do:
catch: /No documents found containing all the required fields \[classification_field_act.keyword, missing.class_name, missing.class_probability\]/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"top_classes_field": "missing",
"metrics": {
"auc_roc": {
"class_name": "dummy"
}
}
}
}
}
---
"Test regression given evaluation with empty metrics":
- do:
catch: /\[regression\] must have one or more metrics/
@ -932,7 +1115,7 @@ setup:
---
"Test regression given missing actual_field":
- do:
catch: /No documents found containing both \[missing, regression_field_pred\] fields/
catch: /No documents found containing all the required fields \[missing, regression_field_pred\]/
ml.evaluate_data_frame:
body: >
{
@ -948,7 +1131,7 @@ setup:
---
"Test regression given missing predicted_field":
- do:
catch: /No documents found containing both \[regression_field_act, missing\] fields/
catch: /No documents found containing all the required fields \[regression_field_act, missing\]/
ml.evaluate_data_frame:
body: >
{