[7.x] Fix accuracy metric (#50310) (#50433)

This commit is contained in:
Przemysław Witek 2019-12-20 15:34:38 +01:00 committed by GitHub
parent 14d95aae46
commit 3e3a93002f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 476 additions and 293 deletions

View File

@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
@ -35,10 +36,25 @@ import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* {@link AccuracyMetric} is a metric that answers the question:
* "What fraction of examples have been classified correctly by the classifier?"
* {@link AccuracyMetric} is a metric that answers the following two questions:
*
* equation: accuracy = 1/n * Σ(y == y´)
* 1. What is the fraction of documents for which predicted class equals the actual class?
*
* equation: overall_accuracy = 1/n * Σ(y == y')
* where: n = total number of documents
* y = document's actual class
* y' = document's predicted class
*
* 2. For any given class X, what is the fraction of documents for which either
* a) both actual and predicted class are equal to X (true positives)
* or
* b) both actual and predicted class are not equal to X (true negatives)
*
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
* where: X = class being examined
* n = total number of documents
* TP(X) = number of true positives wrt X
* TN(X) = number of true negatives wrt X
*/
public class AccuracyMetric implements EvaluationMetric {
@ -78,15 +94,15 @@ public class AccuracyMetric implements EvaluationMetric {
public static class Result implements EvaluationMetric.Result {
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}
@ -94,13 +110,13 @@ public class AccuracyMetric implements EvaluationMetric {
return PARSER.apply(parser, null);
}
/** List of actual classes. */
private final List<ActualClass> actualClasses;
/** Fraction of documents predicted correctly. */
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy;
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
public Result(List<PerClassResult> classes, double overallAccuracy) {
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.overallAccuracy = overallAccuracy;
}
@ -109,8 +125,8 @@ public class AccuracyMetric implements EvaluationMetric {
return NAME;
}
public List<ActualClass> getActualClasses() {
return actualClasses;
public List<PerClassResult> getClasses() {
return classes;
}
public double getOverallAccuracy() {
@ -120,7 +136,7 @@ public class AccuracyMetric implements EvaluationMetric {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
builder.field(CLASSES.getPreferredName(), classes);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject();
return builder;
@ -131,52 +147,42 @@ public class AccuracyMetric implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses)
return Objects.equals(this.classes, that.classes)
&& this.overallAccuracy == that.overallAccuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy);
return Objects.hash(classes, overallAccuracy);
}
}
public static class ActualClass implements ToXContentObject {
public static class PerClassResult implements ToXContentObject {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACCURACY = new ParseField("accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), ACCURACY);
}
/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {code actualClass} class. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
/** Name of the class. */
private final String className;
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final double accuracy;
public ActualClass(
String actualClass, long actualClassDocCount, double accuracy) {
this.actualClass = Objects.requireNonNull(actualClass);
this.actualClassDocCount = actualClassDocCount;
public PerClassResult(String className, double accuracy) {
this.className = Objects.requireNonNull(className);
this.accuracy = accuracy;
}
public String getActualClass() {
return actualClass;
}
public long getActualClassDocCount() {
return actualClassDocCount;
public String getClassName() {
return className;
}
public double getAccuracy() {
@ -186,8 +192,7 @@ public class AccuracyMetric implements EvaluationMetric {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
@ -197,15 +202,19 @@ public class AccuracyMetric implements EvaluationMetric {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.accuracy == that.accuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy);
return Objects.hash(className, accuracy);
}
@Override
public String toString() {
return Strings.toString(this);
}
}
}

View File

@ -1849,15 +1849,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat(
accuracyResult.getActualClasses(),
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly
new AccuracyMetric.ActualClass("cat", 5, 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75),
// no examples labeled as "ant" were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
// 9 out of 10 examples were classified correctly
new AccuracyMetric.PerClassResult("ant", 0.9),
// 6 out of 10 examples were classified correctly
new AccuracyMetric.PerClassResult("cat", 0.6),
// 8 out of 10 examples were classified correctly
new AccuracyMetric.PerClassResult("dog", 0.8))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
}
{ // Precision

View File

@ -19,7 +19,7 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
@ -41,13 +41,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result>
public static Result randomResult() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
classes.add(new PerClassResult(classNames.get(i), accuracy));
}
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy);
return new Result(classes, overallAccuracy);
}
@Override

View File

@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
* Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
*/
Optional<EvaluationMetricResult> getResult();
Optional<? extends EvaluationMetricResult> getResult();
}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -29,7 +29,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
@ -40,22 +39,36 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* {@link Accuracy} is a metric that answers the question:
* "What fraction of examples have been classified correctly by the classifier?"
* {@link Accuracy} is a metric that answers the following two questions:
*
* equation: accuracy = 1/n * Σ(y == y´)
* 1. What is the fraction of documents for which predicted class equals the actual class?
*
* equation: overall_accuracy = 1/n * Σ(y == y')
* where: n = total number of documents
* y = document's actual class
* y' = document's predicted class
*
* 2. For any given class X, what is the fraction of documents for which either
* a) both actual and predicted class are equal to X (true positives)
* or
* b) both actual and predicted class are not equal to X (true negatives)
*
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
* where: X = class being examined
* n = total number of documents
* TP(X) = number of true positives wrt X
* TN(X) = number of true negatives wrt X
*/
public class Accuracy implements EvaluationMetric {
public static final ParseField NAME = new ParseField("accuracy");
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String CLASSES_AGG_NAME = "classification_classes";
private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
private static String buildScript(Object...args) {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
}
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
@ -64,11 +77,20 @@ public class Accuracy implements EvaluationMetric {
return PARSER.apply(parser, null);
}
private EvaluationMetricResult result;
private static final int MAX_CLASSES_CARDINALITY = 1000;
public Accuracy() {}
private final MulticlassConfusionMatrix matrix;
private final SetOnce<String> actualField = new SetOnce<>();
private final SetOnce<Double> overallAccuracy = new SetOnce<>();
private final SetOnce<Result> result = new SetOnce<>();
public Accuracy(StreamInput in) throws IOException {}
public Accuracy() {
this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
}
public Accuracy(StreamInput in) throws IOException {
this.matrix = new MulticlassConfusionMatrix(in);
}
@Override
public String getWriteableName() {
@ -82,43 +104,79 @@ public class Accuracy implements EvaluationMetric {
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField.trySet(actualField);
List<AggregationBuilder> aggs = new ArrayList<>();
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
if (overallAccuracy.get() == null) {
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
}
Script accuracyScript = new Script(buildScript(actualField, predictedField));
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(CLASSES_AGG_NAME)
.field(actualField)
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
Collections.emptyList());
if (result.get() == null) {
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
aggs.addAll(matrixAggs.v1());
pipelineAggs.addAll(matrixAggs.v2());
}
return Tuple.tuple(aggs, pipelineAggs);
}
@Override
public void process(Aggregations aggs) {
if (result != null) {
return;
}
Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
for (Terms.Bucket bucket : classesAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString();
long actualClassDocCount = bucket.getDocCount();
NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
overallAccuracy.set(overallAccuracyAgg.value());
}
matrix.process(aggs);
if (result.get() == null && matrix.getResult().isPresent()) {
if (matrix.getResult().get().getOtherActualClassCount() > 0) {
// This means there were more than {@code maxClassesCardinality} buckets.
// We cannot calculate per-class accuracy accurately, so we fail.
throw ExceptionsHelper.badRequestException(
"Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get());
}
result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get()));
}
result = new Result(actualClasses, overallAccuracyAgg.value());
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
public Optional<Result> getResult() {
return Optional.ofNullable(result.get());
}
/**
* Computes the per-class accuracy results based on multiclass confusion matrix's result.
* Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension.
* This method is visible for testing only.
*/
static List<PerClassResult> computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) {
assert matrixResult.getOtherActualClassCount() == 0;
// Number of actual classes taken into account
int n = matrixResult.getConfusionMatrix().size();
// Total number of documents taken into account
long totalDocCount =
matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum();
List<PerClassResult> classes = new ArrayList<>(n);
for (int i = 0; i < n; ++i) {
String className = matrixResult.getConfusionMatrix().get(i).getActualClass();
// Start with the assumption that all the docs were predicted correctly.
long correctDocCount = totalDocCount;
for (int j = 0; j < n; ++j) {
if (i != j) {
// Subtract errors (false negatives)
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount();
// Subtract errors (false positives)
correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount();
}
}
// Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount();
classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount));
}
return classes;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
matrix.writeTo(out);
}
@Override
@ -132,25 +190,26 @@ public class Accuracy implements EvaluationMetric {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
Accuracy that = (Accuracy) o;
return Objects.equals(this.matrix, that.matrix);
}
@Override
public int hashCode() {
return Objects.hashCode(NAME.getPreferredName());
return Objects.hash(matrix);
}
public static class Result implements EvaluationMetricResult {
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}
@ -158,18 +217,18 @@ public class Accuracy implements EvaluationMetric {
return PARSER.apply(parser, null);
}
/** List of actual classes. */
private final List<ActualClass> actualClasses;
/** Fraction of documents predicted correctly. */
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy;
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
public Result(List<PerClassResult> classes, double overallAccuracy) {
this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
this.overallAccuracy = overallAccuracy;
}
public Result(StreamInput in) throws IOException {
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
this.overallAccuracy = in.readDouble();
}
@ -183,8 +242,8 @@ public class Accuracy implements EvaluationMetric {
return NAME.getPreferredName();
}
public List<ActualClass> getActualClasses() {
return actualClasses;
public List<PerClassResult> getClasses() {
return classes;
}
public double getOverallAccuracy() {
@ -193,14 +252,14 @@ public class Accuracy implements EvaluationMetric {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(actualClasses);
out.writeList(classes);
out.writeDouble(overallAccuracy);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
builder.field(CLASSES.getPreferredName(), classes);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject();
return builder;
@ -211,54 +270,47 @@ public class Accuracy implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses)
return Objects.equals(this.classes, that.classes)
&& this.overallAccuracy == that.overallAccuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy);
return Objects.hash(classes, overallAccuracy);
}
}
public static class ActualClass implements ToXContentObject, Writeable {
public static class PerClassResult implements ToXContentObject, Writeable {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACCURACY = new ParseField("accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), ACCURACY);
}
/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {code actualClass} class. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
/** Name of the class. */
private final String className;
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final double accuracy;
public ActualClass(
String actualClass, long actualClassDocCount, double accuracy) {
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
this.actualClassDocCount = actualClassDocCount;
public PerClassResult(String className, double accuracy) {
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
this.accuracy = accuracy;
}
public ActualClass(StreamInput in) throws IOException {
this.actualClass = in.readString();
this.actualClassDocCount = in.readVLong();
public PerClassResult(StreamInput in) throws IOException {
this.className = in.readString();
this.accuracy = in.readDouble();
}
public String getActualClass() {
return actualClass;
public String getClassName() {
return className;
}
public double getAccuracy() {
@ -267,16 +319,14 @@ public class Accuracy implements EvaluationMetric {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualClass);
out.writeVLong(actualClassDocCount);
out.writeString(className);
out.writeDouble(accuracy);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
@ -286,15 +336,14 @@ public class Accuracy implements EvaluationMetric {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.accuracy == that.accuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy);
return Objects.hash(className, accuracy);
}
}
}

View File

@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
@ -53,13 +55,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
public static final ParseField SIZE = new ParseField("size");
public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix");
private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
new ConstructingObjectParser<>(
NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1]));
parser.declareInt(optionalConstructorArg(), SIZE);
parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX);
return parser;
}
@ -67,31 +72,39 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
return PARSER.apply(parser, null);
}
private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
private static final String OTHER_BUCKET_KEY = "_other_";
private static final String DEFAULT_AGG_NAME_PREFIX = "";
private static final int DEFAULT_SIZE = 10;
private static final int MAX_SIZE = 1000;
private final int size;
private List<String> topActualClassNames;
private Result result;
private final String aggNamePrefix;
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
private final SetOnce<Result> result = new SetOnce<>();
public MulticlassConfusionMatrix() {
this((Integer) null);
this(null, null);
}
public MulticlassConfusionMatrix(@Nullable Integer size) {
public MulticlassConfusionMatrix(@Nullable Integer size, @Nullable String aggNamePrefix) {
if (size != null && (size <= 0 || size > MAX_SIZE)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
}
this.size = size != null ? size : DEFAULT_SIZE;
this.aggNamePrefix = aggNamePrefix != null ? aggNamePrefix : DEFAULT_AGG_NAME_PREFIX;
}
public MulticlassConfusionMatrix(StreamInput in) throws IOException {
this.size = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
this.aggNamePrefix = in.readString();
} else {
this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX;
}
}
@Override
@ -110,30 +123,30 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (topActualClassNames == null) { // This is step 1
if (topActualClassNames.get() == null) { // This is step 1
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)),
Collections.emptyList());
}
if (result == null) { // This is step 2
if (result.get() == null) { // This is step 2
KeyedFilter[] keyedFiltersActual =
topActualClassNames.stream()
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new);
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.stream()
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
.field(actualField),
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY))),
Collections.emptyList());
@ -143,18 +156,18 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
@Override
public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
if (topActualClassNames.get() == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
}
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString();
long actualClassDocCount = bucket.getDocCount();
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
Filters subAgg = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS));
List<PredictedClass> predictedClasses = new ArrayList<>();
long otherPredictedClassDocCount = 0;
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
@ -169,18 +182,25 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
}
result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
}
}
private String aggName(String aggNameWithoutPrefix) {
return aggNamePrefix + aggNameWithoutPrefix;
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
public Optional<Result> getResult() {
return Optional.ofNullable(result.get());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(size);
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeString(aggNamePrefix);
}
}
@Override
@ -196,12 +216,13 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
return Objects.equals(this.size, that.size);
return this.size == that.size
&& Objects.equals(this.aggNamePrefix, that.aggNamePrefix);
}
@Override
public int hashCode() {
return Objects.hash(size);
return Objects.hash(size, aggNamePrefix);
}
public static class Result implements EvaluationMetricResult {
@ -335,6 +356,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
return actualClass;
}
public long getActualClassDocCount() {
return actualClassDocCount;
}
public List<PredictedClass> getPredictedClasses() {
return predictedClasses;
}
@ -411,6 +436,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
return predictedClass;
}
public long getCount() {
return count;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(predictedClass);

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
@ -77,9 +78,9 @@ public class Precision implements EvaluationMetric {
private static final int MAX_CLASSES_CARDINALITY = 1000;
private String actualField;
private List<String> topActualClassNames;
private EvaluationMetricResult result;
private final SetOnce<String> actualField = new SetOnce<>();
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
private final SetOnce<Result> result = new SetOnce<>();
public Precision() {}
@ -98,8 +99,8 @@ public class Precision implements EvaluationMetric {
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField;
if (topActualClassNames == null) { // This is step 1
this.actualField.trySet(actualField);
if (topActualClassNames.get() == null) { // This is step 1
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
@ -108,9 +109,9 @@ public class Precision implements EvaluationMetric {
.size(MAX_CLASSES_CARDINALITY)),
Collections.emptyList());
}
if (result == null) { // This is step 2
if (result.get() == null) { // This is step 2
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.stream()
topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
Script script = buildScript(actualField, predictedField);
@ -128,18 +129,18 @@ public class Precision implements EvaluationMetric {
@Override
public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
// We cannot calculate average precision accurately, so we fail.
throw ExceptionsHelper.badRequestException(
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get());
}
topActualClassNames =
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
topActualClassNames.set(
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
}
if (result == null &&
if (result.get() == null &&
aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
@ -153,13 +154,13 @@ public class Precision implements EvaluationMetric {
classes.add(new PerClassResult(className, precision));
}
}
result = new Result(classes, avgPrecisionAgg.value());
result.set(new Result(classes, avgPrecisionAgg.value()));
}
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
public Optional<Result> getResult() {
return Optional.ofNullable(result.get());
}
@Override

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
@ -71,8 +72,8 @@ public class Recall implements EvaluationMetric {
private static final int MAX_CLASSES_CARDINALITY = 1000;
private String actualField;
private EvaluationMetricResult result;
private final SetOnce<String> actualField = new SetOnce<>();
private final SetOnce<Result> result = new SetOnce<>();
public Recall() {}
@ -91,8 +92,8 @@ public class Recall implements EvaluationMetric {
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField;
if (result != null) {
this.actualField.trySet(actualField);
if (result.get() != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
Script script = buildScript(actualField, predictedField);
@ -110,7 +111,7 @@ public class Recall implements EvaluationMetric {
@Override
public void process(Aggregations aggs) {
if (result == null &&
if (result.get() == null &&
aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
@ -118,7 +119,7 @@ public class Recall implements EvaluationMetric {
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
// We cannot calculate average recall accurately, so we fail.
throw ExceptionsHelper.badRequestException(
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get());
}
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
@ -127,13 +128,13 @@ public class Recall implements EvaluationMetric {
NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
classes.add(new PerClassResult(className, recallAgg.value()));
}
result = new Result(classes, avgRecallAgg.value());
result.set(new Result(classes, avgRecallAgg.value()));
}
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
public Optional<Result> getResult() {
return Optional.ofNullable(result.get());
}
@Override

View File

@ -8,9 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
import java.util.ArrayList;
import java.util.List;
@ -22,13 +22,13 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result>
public static Result createRandom() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
classes.add(new PerClassResult(classNames.get(i), accuracy));
}
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy);
return new Result(classes, overallAccuracy);
}
@Override

View File

@ -5,17 +5,27 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
@ -46,52 +56,114 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms("classification_classes"),
mockSingleValue("classification_overall_accuracy", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
mockTerms(
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L),
mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockFiltersBucket(
"dog",
30,
new Aggregations(Arrays.asList(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(Arrays.asList(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
Accuracy accuracy = new Accuracy();
accuracy.process(aggs);
assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123)));
assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
Result result = accuracy.getResult().get();
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
result.getClasses(),
equalTo(
Arrays.asList(
new PerClassResult("dog", 0.5),
new PerClassResult("cat", 0.5))));
assertThat(result.getOverallAccuracy(), equalTo(0.5));
}
public void testProcess_GivenMissingAgg() {
{
public void testProcess_GivenCardinalityTooHigh() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms("classification_classes"),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
mockTerms(
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L),
mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockFiltersBucket(
"dog",
30,
new Aggregations(Arrays.asList(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(Arrays.asList(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("classification_overall_accuracy", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
}
accuracy.aggs("foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}
public void testProcess_GivenAggOfWrongType() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms("classification_classes"),
mockTerms("classification_overall_accuracy")
));
Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("classification_classes", 1.0),
mockSingleValue("classification_overall_accuracy", 0.8123)
));
Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
public void testComputePerClassAccuracy() {
assertThat(
Accuracy.computePerClassAccuracy(
new MulticlassConfusionMatrix.Result(
Arrays.asList(
new MulticlassConfusionMatrix.ActualClass("A", 14, Arrays.asList(
new MulticlassConfusionMatrix.PredictedClass("A", 1),
new MulticlassConfusionMatrix.PredictedClass("B", 6),
new MulticlassConfusionMatrix.PredictedClass("C", 4)
), 3L),
new MulticlassConfusionMatrix.ActualClass("B", 20, Arrays.asList(
new MulticlassConfusionMatrix.PredictedClass("A", 5),
new MulticlassConfusionMatrix.PredictedClass("B", 3),
new MulticlassConfusionMatrix.PredictedClass("C", 9)
), 3L),
new MulticlassConfusionMatrix.ActualClass("C", 17, Arrays.asList(
new MulticlassConfusionMatrix.PredictedClass("A", 8),
new MulticlassConfusionMatrix.PredictedClass("B", 2),
new MulticlassConfusionMatrix.PredictedClass("C", 7)
), 0L)),
0)),
equalTo(
Arrays.asList(
new Accuracy.PerClassResult("A", 25.0 / 51), // 13 false positives, 13 false negatives
new Accuracy.PerClassResult("B", 26.0 / 51), // 8 false positives, 17 false negatives
new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives
);
}
public void testComputePerClassAccuracy_OtherActualClassCountIsNonZero() {
expectThrows(
AssertionError.class,
() -> Accuracy.computePerClassAccuracy(new MulticlassConfusionMatrix.Result(Collections.emptyList(), 1)));
}
}

View File

@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
import java.io.IOException;
import java.util.Arrays;
@ -56,20 +57,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public static MulticlassConfusionMatrix createRandom() {
Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
return new MulticlassConfusionMatrix(size);
return new MulticlassConfusionMatrix(size, null);
}
public void testConstructor_SizeValidationFailures() {
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1));
ElasticsearchStatusException e =
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1, null));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
}
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0));
ElasticsearchStatusException e =
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0, null));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
}
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001));
ElasticsearchStatusException e =
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001, null));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
}
}
@ -84,36 +88,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class",
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
0L),
mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockFiltersBucket(
"dog",
30,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(
result.getConfusionMatrix(),
equalTo(
@ -126,36 +130,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public void testEvaluate_OtherClassesCountGreaterThanZero() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class",
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L),
mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockFiltersBucket(
"dog",
30,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
85,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(
result.getConfusionMatrix(),
equalTo(

View File

@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
@ -22,6 +23,8 @@ import org.junit.Before;
import java.util.Arrays;
import java.util.List;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@ -53,10 +56,28 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
assertThat(
evaluateDataFrameResponse.getMetrics().get(0).getMetricName(),
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
contains(MulticlassConfusionMatrix.NAME.getPreferredName()));
}
public void testEvaluate_AllMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(
ANIMAL_NAME_FIELD,
ANIMAL_NAME_PREDICTION_FIELD,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
contains(
Accuracy.NAME.getPreferredName(),
MulticlassConfusionMatrix.NAME.getPreferredName(),
Precision.NAME.getPreferredName(),
Recall.NAME.getPreferredName()));
}
public void testEvaluate_Accuracy_KeywordField() {
@ -70,14 +91,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getActualClasses(),
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
new Accuracy.ActualClass("ant", 15, 1.0 / 15),
new Accuracy.ActualClass("cat", 15, 1.0 / 15),
new Accuracy.ActualClass("dog", 15, 1.0 / 15),
new Accuracy.ActualClass("fox", 15, 1.0 / 15),
new Accuracy.ActualClass("mouse", 15, 1.0 / 15))));
new Accuracy.PerClassResult("ant", 47.0 / 75),
new Accuracy.PerClassResult("cat", 47.0 / 75),
new Accuracy.PerClassResult("dog", 47.0 / 75),
new Accuracy.PerClassResult("fox", 47.0 / 75),
new Accuracy.PerClassResult("mouse", 47.0 / 75))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
}
@ -92,13 +113,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getActualClasses(),
equalTo(Arrays.asList(
new Accuracy.ActualClass("1", 15, 1.0 / 15),
new Accuracy.ActualClass("2", 15, 2.0 / 15),
new Accuracy.ActualClass("3", 15, 3.0 / 15),
new Accuracy.ActualClass("4", 15, 4.0 / 15),
new Accuracy.ActualClass("5", 15, 5.0 / 15))));
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
new Accuracy.PerClassResult("1", 57.0 / 75),
new Accuracy.PerClassResult("2", 54.0 / 75),
new Accuracy.PerClassResult("3", 51.0 / 75),
new Accuracy.PerClassResult("4", 48.0 / 75),
new Accuracy.PerClassResult("5", 45.0 / 75))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
}
@ -113,10 +135,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getActualClasses(),
equalTo(Arrays.asList(
new Accuracy.ActualClass("true", 45, 27.0 / 45),
new Accuracy.ActualClass("false", 30, 18.0 / 30))));
accuracyResult.getClasses(),
equalTo(
Arrays.asList(
new Accuracy.PerClassResult("false", 18.0 / 30),
new Accuracy.PerClassResult("true", 27.0 / 45))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
}
@ -252,7 +275,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));

View File

@ -466,12 +466,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
{ // Accuracy
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
List<Accuracy.ActualClass> actualClasses = accuracyResult.getActualClasses();
assertThat(
actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()),
equalTo(dependentVariableValuesAsStrings));
actualClasses.forEach(
actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) {
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
}
}
{ // MulticlassConfusionMatrix

View File

@ -620,16 +620,13 @@ setup:
- match:
classification.accuracy:
actual_classes:
- actual_class: "cat"
actual_class_doc_count: 3
accuracy: 0.6666666666666666 # 2 out of 3
- actual_class: "dog"
actual_class_doc_count: 3
accuracy: 0.6666666666666666 # 2 out of 3
- actual_class: "mouse"
actual_class_doc_count: 2
accuracy: 0.5 # 1 out of 2
classes:
- class_name: "cat"
accuracy: 0.625 # 5 out of 8
- class_name: "dog"
accuracy: 0.75 # 6 out of 8
- class_name: "mouse"
accuracy: 0.875 # 7 out of 8
overall_accuracy: 0.625 # 5 out of 8
---
"Test classification precision":