[7.x] Fix accuracy metric (#50310) (#50433)

This commit is contained in:
Przemysław Witek 2019-12-20 15:34:38 +01:00 committed by GitHub
parent 14d95aae46
commit 3e3a93002f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 476 additions and 293 deletions

View File

@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
@ -35,10 +36,25 @@ import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/** /**
* {@link AccuracyMetric} is a metric that answers the question: * {@link AccuracyMetric} is a metric that answers the following two questions:
* "What fraction of examples have been classified correctly by the classifier?"
* *
* equation: accuracy = 1/n * Σ(y == y´) * 1. What is the fraction of documents for which predicted class equals the actual class?
*
* equation: overall_accuracy = 1/n * Σ(y == y')
* where: n = total number of documents
* y = document's actual class
* y' = document's predicted class
*
* 2. For any given class X, what is the fraction of documents for which either
* a) both actual and predicted class are equal to X (true positives)
* or
* b) both actual and predicted class are not equal to X (true negatives)
*
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
* where: X = class being examined
* n = total number of documents
* TP(X) = number of true positives wrt X
* TN(X) = number of true negatives wrt X
*/ */
public class AccuracyMetric implements EvaluationMetric { public class AccuracyMetric implements EvaluationMetric {
@ -78,15 +94,15 @@ public class AccuracyMetric implements EvaluationMetric {
public static class Result implements EvaluationMetric.Result { public static class Result implements EvaluationMetric.Result {
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1])); new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static { static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
} }
@ -94,13 +110,13 @@ public class AccuracyMetric implements EvaluationMetric {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
/** List of actual classes. */ /** List of per-class results. */
private final List<ActualClass> actualClasses; private final List<PerClassResult> classes;
/** Fraction of documents predicted correctly. */ /** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy; private final double overallAccuracy;
public Result(List<ActualClass> actualClasses, double overallAccuracy) { public Result(List<PerClassResult> classes, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses)); this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.overallAccuracy = overallAccuracy; this.overallAccuracy = overallAccuracy;
} }
@ -109,8 +125,8 @@ public class AccuracyMetric implements EvaluationMetric {
return NAME; return NAME;
} }
public List<ActualClass> getActualClasses() { public List<PerClassResult> getClasses() {
return actualClasses; return classes;
} }
public double getOverallAccuracy() { public double getOverallAccuracy() {
@ -120,7 +136,7 @@ public class AccuracyMetric implements EvaluationMetric {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); builder.field(CLASSES.getPreferredName(), classes);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject(); builder.endObject();
return builder; return builder;
@ -131,52 +147,42 @@ public class AccuracyMetric implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses) return Objects.equals(this.classes, that.classes)
&& this.overallAccuracy == that.overallAccuracy; && this.overallAccuracy == that.overallAccuracy;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy); return Objects.hash(classes, overallAccuracy);
} }
} }
public static class ActualClass implements ToXContentObject { public static class PerClassResult implements ToXContentObject {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField ACCURACY = new ParseField("accuracy"); private static final ParseField ACCURACY = new ParseField("accuracy");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER = private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static { static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS); PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareDouble(constructorArg(), ACCURACY); PARSER.declareDouble(constructorArg(), ACCURACY);
} }
/** Name of the actual class. */ /** Name of the class. */
private final String actualClass; private final String className;
/** Number of documents (examples) belonging to the {code actualClass} class. */ /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
private final double accuracy; private final double accuracy;
public ActualClass( public PerClassResult(String className, double accuracy) {
String actualClass, long actualClassDocCount, double accuracy) { this.className = Objects.requireNonNull(className);
this.actualClass = Objects.requireNonNull(actualClass);
this.actualClassDocCount = actualClassDocCount;
this.accuracy = accuracy; this.accuracy = accuracy;
} }
public String getActualClass() { public String getClassName() {
return actualClass; return className;
}
public long getActualClassDocCount() {
return actualClassDocCount;
} }
public double getAccuracy() { public double getAccuracy() {
@ -186,8 +192,7 @@ public class AccuracyMetric implements EvaluationMetric {
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(ACCURACY.getPreferredName(), accuracy); builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject(); builder.endObject();
return builder; return builder;
@ -197,15 +202,19 @@ public class AccuracyMetric implements EvaluationMetric {
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o; PerClassResult that = (PerClassResult) o;
return Objects.equals(this.actualClass, that.actualClass) return Objects.equals(this.className, that.className)
&& this.actualClassDocCount == that.actualClassDocCount
&& this.accuracy == that.accuracy; && this.accuracy == that.accuracy;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy); return Objects.hash(className, accuracy);
}
@Override
public String toString() {
return Strings.toString(this);
} }
} }
} }

View File

@ -1849,15 +1849,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME); AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat( assertThat(
accuracyResult.getActualClasses(), accuracyResult.getClasses(),
equalTo( equalTo(
Arrays.asList( Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly // 9 out of 10 examples were classified correctly
new AccuracyMetric.ActualClass("cat", 5, 0.6), new AccuracyMetric.PerClassResult("ant", 0.9),
// 3 out of 4 examples labeled as "dog" were classified correctly // 6 out of 10 examples were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75), new AccuracyMetric.PerClassResult("cat", 0.6),
// no examples labeled as "ant" were classified correctly // 8 out of 10 examples were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0)))); new AccuracyMetric.PerClassResult("dog", 0.8))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
} }
{ // Precision { // Precision

View File

@ -19,7 +19,7 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification; package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
@ -41,13 +41,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result>
public static Result randomResult() { public static Result randomResult() {
int numClasses = randomIntBetween(2, 100); int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses); List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) { for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true); double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); classes.add(new PerClassResult(classNames.get(i), accuracy));
} }
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy); return new Result(classes, overallAccuracy);
} }
@Override @Override

View File

@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
* Gets the evaluation result for this metric. * Gets the evaluation result for this metric.
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
*/ */
Optional<EvaluationMetricResult> getResult(); Optional<? extends EvaluationMetricResult> getResult();
} }

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@ -29,7 +29,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.text.MessageFormat; import java.text.MessageFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
@ -40,22 +39,36 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/** /**
* {@link Accuracy} is a metric that answers the question: * {@link Accuracy} is a metric that answers the following two questions:
* "What fraction of examples have been classified correctly by the classifier?"
* *
* equation: accuracy = 1/n * Σ(y == y´) * 1. What is the fraction of documents for which predicted class equals the actual class?
*
* equation: overall_accuracy = 1/n * Σ(y == y')
* where: n = total number of documents
* y = document's actual class
* y' = document's predicted class
*
* 2. For any given class X, what is the fraction of documents for which either
* a) both actual and predicted class are equal to X (true positives)
* or
* b) both actual and predicted class are not equal to X (true negatives)
*
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
* where: X = class being examined
* n = total number of documents
* TP(X) = number of true positives wrt X
* TN(X) = number of true negatives wrt X
*/ */
public class Accuracy implements EvaluationMetric { public class Accuracy implements EvaluationMetric {
public static final ParseField NAME = new ParseField("accuracy"); public static final ParseField NAME = new ParseField("accuracy");
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
private static final String CLASSES_AGG_NAME = "classification_classes";
private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
private static String buildScript(Object...args) { private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
} }
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new); private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
@ -64,11 +77,20 @@ public class Accuracy implements EvaluationMetric {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private EvaluationMetricResult result; private static final int MAX_CLASSES_CARDINALITY = 1000;
public Accuracy() {} private final MulticlassConfusionMatrix matrix;
private final SetOnce<String> actualField = new SetOnce<>();
private final SetOnce<Double> overallAccuracy = new SetOnce<>();
private final SetOnce<Result> result = new SetOnce<>();
public Accuracy(StreamInput in) throws IOException {} public Accuracy() {
this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
}
public Accuracy(StreamInput in) throws IOException {
this.matrix = new MulticlassConfusionMatrix(in);
}
@Override @Override
public String getWriteableName() { public String getWriteableName() {
@ -82,43 +104,79 @@ public class Accuracy implements EvaluationMetric {
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (result != null) { // Store given {@code actualField} for the purpose of generating error message in {@code process}.
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); this.actualField.trySet(actualField);
List<AggregationBuilder> aggs = new ArrayList<>();
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
if (overallAccuracy.get() == null) {
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
} }
Script accuracyScript = new Script(buildScript(actualField, predictedField)); if (result.get() == null) {
return Tuple.tuple( Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
Arrays.asList( aggs.addAll(matrixAggs.v1());
AggregationBuilders.terms(CLASSES_AGG_NAME) pipelineAggs.addAll(matrixAggs.v2());
.field(actualField) }
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)), return Tuple.tuple(aggs, pipelineAggs);
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
Collections.emptyList());
} }
@Override @Override
public void process(Aggregations aggs) { public void process(Aggregations aggs) {
if (result != null) { if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
return; NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
overallAccuracy.set(overallAccuracyAgg.value());
} }
Terms classesAgg = aggs.get(CLASSES_AGG_NAME); matrix.process(aggs);
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME); if (result.get() == null && matrix.getResult().isPresent()) {
List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size()); if (matrix.getResult().get().getOtherActualClassCount() > 0) {
for (Terms.Bucket bucket : classesAgg.getBuckets()) { // This means there were more than {@code maxClassesCardinality} buckets.
String actualClass = bucket.getKeyAsString(); // We cannot calculate per-class accuracy accurately, so we fail.
long actualClassDocCount = bucket.getDocCount(); throw ExceptionsHelper.badRequestException(
NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME); "Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get());
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value())); }
result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get()));
} }
result = new Result(actualClasses, overallAccuracyAgg.value());
} }
@Override @Override
public Optional<EvaluationMetricResult> getResult() { public Optional<Result> getResult() {
return Optional.ofNullable(result); return Optional.ofNullable(result.get());
}
/**
* Computes the per-class accuracy results based on multiclass confusion matrix's result.
* Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension.
* This method is visible for testing only.
*/
static List<PerClassResult> computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) {
assert matrixResult.getOtherActualClassCount() == 0;
// Number of actual classes taken into account
int n = matrixResult.getConfusionMatrix().size();
// Total number of documents taken into account
long totalDocCount =
matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum();
List<PerClassResult> classes = new ArrayList<>(n);
for (int i = 0; i < n; ++i) {
String className = matrixResult.getConfusionMatrix().get(i).getActualClass();
// Start with the assumption that all the docs were predicted correctly.
long correctDocCount = totalDocCount;
for (int j = 0; j < n; ++j) {
if (i != j) {
// Subtract errors (false negatives)
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount();
// Subtract errors (false positives)
correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount();
}
}
// Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount();
classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount));
}
return classes;
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
matrix.writeTo(out);
} }
@Override @Override
@ -132,25 +190,26 @@ public class Accuracy implements EvaluationMetric {
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
return true; Accuracy that = (Accuracy) o;
return Objects.equals(this.matrix, that.matrix);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hashCode(NAME.getPreferredName()); return Objects.hash(matrix);
} }
public static class Result implements EvaluationMetricResult { public static class Result implements EvaluationMetricResult {
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1])); new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static { static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
} }
@ -158,18 +217,18 @@ public class Accuracy implements EvaluationMetric {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
/** List of actual classes. */ /** List of per-class results. */
private final List<ActualClass> actualClasses; private final List<PerClassResult> classes;
/** Fraction of documents predicted correctly. */ /** Fraction of documents for which predicted class equals the actual class. */
private final double overallAccuracy; private final double overallAccuracy;
public Result(List<ActualClass> actualClasses, double overallAccuracy) { public Result(List<PerClassResult> classes, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES)); this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
this.overallAccuracy = overallAccuracy; this.overallAccuracy = overallAccuracy;
} }
public Result(StreamInput in) throws IOException { public Result(StreamInput in) throws IOException {
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
this.overallAccuracy = in.readDouble(); this.overallAccuracy = in.readDouble();
} }
@ -183,8 +242,8 @@ public class Accuracy implements EvaluationMetric {
return NAME.getPreferredName(); return NAME.getPreferredName();
} }
public List<ActualClass> getActualClasses() { public List<PerClassResult> getClasses() {
return actualClasses; return classes;
} }
public double getOverallAccuracy() { public double getOverallAccuracy() {
@ -193,14 +252,14 @@ public class Accuracy implements EvaluationMetric {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeList(actualClasses); out.writeList(classes);
out.writeDouble(overallAccuracy); out.writeDouble(overallAccuracy);
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); builder.field(CLASSES.getPreferredName(), classes);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject(); builder.endObject();
return builder; return builder;
@ -211,54 +270,47 @@ public class Accuracy implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses) return Objects.equals(this.classes, that.classes)
&& this.overallAccuracy == that.overallAccuracy; && this.overallAccuracy == that.overallAccuracy;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy); return Objects.hash(classes, overallAccuracy);
} }
} }
public static class ActualClass implements ToXContentObject, Writeable { public static class PerClassResult implements ToXContentObject, Writeable {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField ACCURACY = new ParseField("accuracy"); private static final ParseField ACCURACY = new ParseField("accuracy");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER = private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static { static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS); PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareDouble(constructorArg(), ACCURACY); PARSER.declareDouble(constructorArg(), ACCURACY);
} }
/** Name of the actual class. */ /** Name of the class. */
private final String actualClass; private final String className;
/** Number of documents (examples) belonging to the {code actualClass} class. */ /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
private final double accuracy; private final double accuracy;
public ActualClass( public PerClassResult(String className, double accuracy) {
String actualClass, long actualClassDocCount, double accuracy) { this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
this.actualClassDocCount = actualClassDocCount;
this.accuracy = accuracy; this.accuracy = accuracy;
} }
public ActualClass(StreamInput in) throws IOException { public PerClassResult(StreamInput in) throws IOException {
this.actualClass = in.readString(); this.className = in.readString();
this.actualClassDocCount = in.readVLong();
this.accuracy = in.readDouble(); this.accuracy = in.readDouble();
} }
public String getActualClass() { public String getClassName() {
return actualClass; return className;
} }
public double getAccuracy() { public double getAccuracy() {
@ -267,16 +319,14 @@ public class Accuracy implements EvaluationMetric {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualClass); out.writeString(className);
out.writeVLong(actualClassDocCount);
out.writeDouble(accuracy); out.writeDouble(accuracy);
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(ACCURACY.getPreferredName(), accuracy); builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject(); builder.endObject();
return builder; return builder;
@ -286,15 +336,14 @@ public class Accuracy implements EvaluationMetric {
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o; PerClassResult that = (PerClassResult) o;
return Objects.equals(this.actualClass, that.actualClass) return Objects.equals(this.className, that.className)
&& this.actualClassDocCount == that.actualClassDocCount
&& this.accuracy == that.accuracy; && this.accuracy == that.accuracy;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy); return Objects.hash(className, accuracy);
} }
} }
} }

View File

@ -5,6 +5,8 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
@ -53,13 +55,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix"); public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
public static final ParseField SIZE = new ParseField("size"); public static final ParseField SIZE = new ParseField("size");
public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix");
private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser(); private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() { private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser = ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0])); new ConstructingObjectParser<>(
NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1]));
parser.declareInt(optionalConstructorArg(), SIZE); parser.declareInt(optionalConstructorArg(), SIZE);
parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX);
return parser; return parser;
} }
@ -67,31 +72,39 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
private static final String OTHER_BUCKET_KEY = "_other_"; private static final String OTHER_BUCKET_KEY = "_other_";
private static final String DEFAULT_AGG_NAME_PREFIX = "";
private static final int DEFAULT_SIZE = 10; private static final int DEFAULT_SIZE = 10;
private static final int MAX_SIZE = 1000; private static final int MAX_SIZE = 1000;
private final int size; private final int size;
private List<String> topActualClassNames; private final String aggNamePrefix;
private Result result; private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
private final SetOnce<Result> result = new SetOnce<>();
public MulticlassConfusionMatrix() { public MulticlassConfusionMatrix() {
this((Integer) null); this(null, null);
} }
public MulticlassConfusionMatrix(@Nullable Integer size) { public MulticlassConfusionMatrix(@Nullable Integer size, @Nullable String aggNamePrefix) {
if (size != null && (size <= 0 || size > MAX_SIZE)) { if (size != null && (size <= 0 || size > MAX_SIZE)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE); throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
} }
this.size = size != null ? size : DEFAULT_SIZE; this.size = size != null ? size : DEFAULT_SIZE;
this.aggNamePrefix = aggNamePrefix != null ? aggNamePrefix : DEFAULT_AGG_NAME_PREFIX;
} }
public MulticlassConfusionMatrix(StreamInput in) throws IOException { public MulticlassConfusionMatrix(StreamInput in) throws IOException {
this.size = in.readVInt(); this.size = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
this.aggNamePrefix = in.readString();
} else {
this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX;
}
} }
@Override @Override
@ -110,30 +123,30 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (topActualClassNames == null) { // This is step 1 if (topActualClassNames.get() == null) { // This is step 1
return Tuple.tuple( return Tuple.tuple(
Arrays.asList( Arrays.asList(
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
.field(actualField) .field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)), .size(size)),
Collections.emptyList()); Collections.emptyList());
} }
if (result == null) { // This is step 2 if (result.get() == null) { // This is step 2
KeyedFilter[] keyedFiltersActual = KeyedFilter[] keyedFiltersActual =
topActualClassNames.stream() topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new); .toArray(KeyedFilter[]::new);
KeyedFilter[] keyedFiltersPredicted = KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.stream() topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new); .toArray(KeyedFilter[]::new);
return Tuple.tuple( return Tuple.tuple(
Arrays.asList( Arrays.asList(
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
.field(actualField), .field(actualField),
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual) AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted) .subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
.otherBucket(true) .otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY))), .otherBucketKey(OTHER_BUCKET_KEY))),
Collections.emptyList()); Collections.emptyList());
@ -143,18 +156,18 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
@Override @Override
public void process(Aggregations aggs) { public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { if (topActualClassNames.get() == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
} }
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size()); List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
for (Filters.Bucket bucket : filtersAgg.getBuckets()) { for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString(); String actualClass = bucket.getKeyAsString();
long actualClassDocCount = bucket.getDocCount(); long actualClassDocCount = bucket.getDocCount();
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); Filters subAgg = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS));
List<PredictedClass> predictedClasses = new ArrayList<>(); List<PredictedClass> predictedClasses = new ArrayList<>();
long otherPredictedClassDocCount = 0; long otherPredictedClassDocCount = 0;
for (Filters.Bucket subBucket : subAgg.getBuckets()) { for (Filters.Bucket subBucket : subAgg.getBuckets()) {
@ -169,18 +182,25 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount)); actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
} }
result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
} }
} }
private String aggName(String aggNameWithoutPrefix) {
return aggNamePrefix + aggNameWithoutPrefix;
}
@Override @Override
public Optional<EvaluationMetricResult> getResult() { public Optional<Result> getResult() {
return Optional.ofNullable(result); return Optional.ofNullable(result.get());
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(size); out.writeVInt(size);
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeString(aggNamePrefix);
}
} }
@Override @Override
@ -196,12 +216,13 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o; MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
return Objects.equals(this.size, that.size); return this.size == that.size
&& Objects.equals(this.aggNamePrefix, that.aggNamePrefix);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(size); return Objects.hash(size, aggNamePrefix);
} }
public static class Result implements EvaluationMetricResult { public static class Result implements EvaluationMetricResult {
@ -335,6 +356,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
return actualClass; return actualClass;
} }
public long getActualClassDocCount() {
return actualClassDocCount;
}
public List<PredictedClass> getPredictedClasses() { public List<PredictedClass> getPredictedClasses() {
return predictedClasses; return predictedClasses;
} }
@ -411,6 +436,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
return predictedClass; return predictedClass;
} }
public long getCount() {
return count;
}
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeString(predictedClass); out.writeString(predictedClass);

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@ -77,9 +78,9 @@ public class Precision implements EvaluationMetric {
private static final int MAX_CLASSES_CARDINALITY = 1000; private static final int MAX_CLASSES_CARDINALITY = 1000;
private String actualField; private final SetOnce<String> actualField = new SetOnce<>();
private List<String> topActualClassNames; private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
private EvaluationMetricResult result; private final SetOnce<Result> result = new SetOnce<>();
public Precision() {} public Precision() {}
@ -98,8 +99,8 @@ public class Precision implements EvaluationMetric {
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}. // Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField; this.actualField.trySet(actualField);
if (topActualClassNames == null) { // This is step 1 if (topActualClassNames.get() == null) { // This is step 1
return Tuple.tuple( return Tuple.tuple(
Arrays.asList( Arrays.asList(
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME) AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
@ -108,9 +109,9 @@ public class Precision implements EvaluationMetric {
.size(MAX_CLASSES_CARDINALITY)), .size(MAX_CLASSES_CARDINALITY)),
Collections.emptyList()); Collections.emptyList());
} }
if (result == null) { // This is step 2 if (result.get() == null) { // This is step 2
KeyedFilter[] keyedFiltersPredicted = KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.stream() topActualClassNames.get().stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new); .toArray(KeyedFilter[]::new);
Script script = buildScript(actualField, predictedField); Script script = buildScript(actualField, predictedField);
@ -128,18 +129,18 @@ public class Precision implements EvaluationMetric {
@Override @Override
public void process(Aggregations aggs) { public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME); Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) { if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
// We cannot calculate average precision accurately, so we fail. // We cannot calculate average precision accurately, so we fail.
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField); "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get());
} }
topActualClassNames = topActualClassNames.set(
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
} }
if (result == null && if (result.get() == null &&
aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters && aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME); Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
@ -153,13 +154,13 @@ public class Precision implements EvaluationMetric {
classes.add(new PerClassResult(className, precision)); classes.add(new PerClassResult(className, precision));
} }
} }
result = new Result(classes, avgPrecisionAgg.value()); result.set(new Result(classes, avgPrecisionAgg.value()));
} }
} }
@Override @Override
public Optional<EvaluationMetricResult> getResult() { public Optional<Result> getResult() {
return Optional.ofNullable(result); return Optional.ofNullable(result.get());
} }
@Override @Override

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@ -71,8 +72,8 @@ public class Recall implements EvaluationMetric {
private static final int MAX_CLASSES_CARDINALITY = 1000; private static final int MAX_CLASSES_CARDINALITY = 1000;
private String actualField; private final SetOnce<String> actualField = new SetOnce<>();
private EvaluationMetricResult result; private final SetOnce<Result> result = new SetOnce<>();
public Recall() {} public Recall() {}
@ -91,8 +92,8 @@ public class Recall implements EvaluationMetric {
@Override @Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) { public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}. // Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField; this.actualField.trySet(actualField);
if (result != null) { if (result.get() != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
} }
Script script = buildScript(actualField, predictedField); Script script = buildScript(actualField, predictedField);
@ -110,7 +111,7 @@ public class Recall implements EvaluationMetric {
@Override @Override
public void process(Aggregations aggs) { public void process(Aggregations aggs) {
if (result == null && if (result.get() == null &&
aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms && aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME); Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
@ -118,7 +119,7 @@ public class Recall implements EvaluationMetric {
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
// We cannot calculate average recall accurately, so we fail. // We cannot calculate average recall accurately, so we fail.
throw ExceptionsHelper.badRequestException( throw ExceptionsHelper.badRequestException(
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField); "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get());
} }
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME); NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size()); List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
@ -127,13 +128,13 @@ public class Recall implements EvaluationMetric {
NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME); NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
classes.add(new PerClassResult(className, recallAgg.value())); classes.add(new PerClassResult(className, recallAgg.value()));
} }
result = new Result(classes, avgRecallAgg.value()); result.set(new Result(classes, avgRecallAgg.value()));
} }
} }
@Override @Override
public Optional<EvaluationMetricResult> getResult() { public Optional<Result> getResult() {
return Optional.ofNullable(result); return Optional.ofNullable(result.get());
} }
@Override @Override

View File

@ -8,9 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -22,13 +22,13 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result>
public static Result createRandom() { public static Result createRandom() {
int numClasses = randomIntBetween(2, 100); int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses); List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) { for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true); double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); classes.add(new PerClassResult(classNames.get(i), accuracy));
} }
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy); return new Result(classes, overallAccuracy);
} }
@Override @Override

View File

@ -5,17 +5,27 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> { public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
@ -46,52 +56,114 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
public void testProcess() { public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms("classification_classes"), mockTerms(
mockSingleValue("classification_overall_accuracy", 0.8123), "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
mockSingleValue("some_other_single_metric_agg", 0.2377) Arrays.asList(
)); mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L),
mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList(
mockFiltersBucket(
"dog",
30,
new Aggregations(Arrays.asList(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(Arrays.asList(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
Accuracy accuracy = new Accuracy(); Accuracy accuracy = new Accuracy();
accuracy.process(aggs); accuracy.process(aggs);
assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123))); assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
Result result = accuracy.getResult().get();
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
result.getClasses(),
equalTo(
Arrays.asList(
new PerClassResult("dog", 0.5),
new PerClassResult("cat", 0.5))));
assertThat(result.getOverallAccuracy(), equalTo(0.5));
} }
public void testProcess_GivenMissingAgg() { public void testProcess_GivenCardinalityTooHigh() {
{ Aggregations aggs = new Aggregations(Arrays.asList(
Aggregations aggs = new Aggregations(Arrays.asList( mockTerms(
mockTerms("classification_classes"), "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
mockSingleValue("some_other_single_metric_agg", 0.2377) Arrays.asList(
)); mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
Accuracy accuracy = new Accuracy(); mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); 100L),
} mockFilters(
{ "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Aggregations aggs = new Aggregations(Arrays.asList( Arrays.asList(
mockSingleValue("classification_overall_accuracy", 0.8123), mockFiltersBucket(
mockSingleValue("some_other_single_metric_agg", 0.2377) "dog",
)); 30,
Accuracy accuracy = new Accuracy(); new Aggregations(Arrays.asList(mockFilters(
expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
} Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket(
"cat",
70,
new Aggregations(Arrays.asList(mockFilters(
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
Accuracy accuracy = new Accuracy();
accuracy.aggs("foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
} }
public void testProcess_GivenAggOfWrongType() { public void testComputePerClassAccuracy() {
{ assertThat(
Aggregations aggs = new Aggregations(Arrays.asList( Accuracy.computePerClassAccuracy(
mockTerms("classification_classes"), new MulticlassConfusionMatrix.Result(
mockTerms("classification_overall_accuracy") Arrays.asList(
)); new MulticlassConfusionMatrix.ActualClass("A", 14, Arrays.asList(
Accuracy accuracy = new Accuracy(); new MulticlassConfusionMatrix.PredictedClass("A", 1),
expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); new MulticlassConfusionMatrix.PredictedClass("B", 6),
} new MulticlassConfusionMatrix.PredictedClass("C", 4)
{ ), 3L),
Aggregations aggs = new Aggregations(Arrays.asList( new MulticlassConfusionMatrix.ActualClass("B", 20, Arrays.asList(
mockSingleValue("classification_classes", 1.0), new MulticlassConfusionMatrix.PredictedClass("A", 5),
mockSingleValue("classification_overall_accuracy", 0.8123) new MulticlassConfusionMatrix.PredictedClass("B", 3),
)); new MulticlassConfusionMatrix.PredictedClass("C", 9)
Accuracy accuracy = new Accuracy(); ), 3L),
expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); new MulticlassConfusionMatrix.ActualClass("C", 17, Arrays.asList(
} new MulticlassConfusionMatrix.PredictedClass("A", 8),
new MulticlassConfusionMatrix.PredictedClass("B", 2),
new MulticlassConfusionMatrix.PredictedClass("C", 7)
), 0L)),
0)),
equalTo(
Arrays.asList(
new Accuracy.PerClassResult("A", 25.0 / 51), // 13 false positives, 13 false negatives
new Accuracy.PerClassResult("B", 26.0 / 51), // 8 false positives, 17 false negatives
new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives
);
}
public void testComputePerClassAccuracy_OtherActualClassCountIsNonZero() {
expectThrows(
AssertionError.class,
() -> Accuracy.computePerClassAccuracy(new MulticlassConfusionMatrix.Result(Collections.emptyList(), 1)));
} }
} }

View File

@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -56,20 +57,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public static MulticlassConfusionMatrix createRandom() { public static MulticlassConfusionMatrix createRandom() {
Integer size = randomBoolean() ? null : randomIntBetween(1, 1000); Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
return new MulticlassConfusionMatrix(size); return new MulticlassConfusionMatrix(size, null);
} }
public void testConstructor_SizeValidationFailures() { public void testConstructor_SizeValidationFailures() {
{ {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1)); ElasticsearchStatusException e =
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1, null));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
} }
{ {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0)); ElasticsearchStatusException e =
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0, null));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
} }
{ {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001)); ElasticsearchStatusException e =
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001, null));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
} }
} }
@ -84,36 +88,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms( mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class", MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
0L), 0L),
mockFilters( mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class", MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket( mockFiltersBucket(
"dog", "dog",
30, 30,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket( mockFiltersBucket(
"cat", "cat",
70, 70,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs); confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat( assertThat(
result.getConfusionMatrix(), result.getConfusionMatrix(),
equalTo( equalTo(
@ -126,36 +130,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public void testEvaluate_OtherClassesCountGreaterThanZero() { public void testEvaluate_OtherClassesCountGreaterThanZero() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms( mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class", MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L), 100L),
mockFilters( mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class", MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket( mockFiltersBucket(
"dog", "dog",
30, 30,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockFiltersBucket( mockFiltersBucket(
"cat", "cat",
85, 85,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
confusionMatrix.process(aggs); confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); Result result = confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat( assertThat(
result.getConfusionMatrix(), result.getConfusionMatrix(),
equalTo( equalTo(

View File

@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
@ -22,6 +23,8 @@ import org.junit.Before;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
@ -53,10 +56,28 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
assertThat( assertThat(
evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); contains(MulticlassConfusionMatrix.NAME.getPreferredName()));
}
public void testEvaluate_AllMetrics() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(
ANIMAL_NAME_FIELD,
ANIMAL_NAME_PREDICTION_FIELD,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
contains(
Accuracy.NAME.getPreferredName(),
MulticlassConfusionMatrix.NAME.getPreferredName(),
Precision.NAME.getPreferredName(),
Recall.NAME.getPreferredName()));
} }
public void testEvaluate_Accuracy_KeywordField() { public void testEvaluate_Accuracy_KeywordField() {
@ -70,14 +91,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat( assertThat(
accuracyResult.getActualClasses(), accuracyResult.getClasses(),
equalTo( equalTo(
Arrays.asList( Arrays.asList(
new Accuracy.ActualClass("ant", 15, 1.0 / 15), new Accuracy.PerClassResult("ant", 47.0 / 75),
new Accuracy.ActualClass("cat", 15, 1.0 / 15), new Accuracy.PerClassResult("cat", 47.0 / 75),
new Accuracy.ActualClass("dog", 15, 1.0 / 15), new Accuracy.PerClassResult("dog", 47.0 / 75),
new Accuracy.ActualClass("fox", 15, 1.0 / 15), new Accuracy.PerClassResult("fox", 47.0 / 75),
new Accuracy.ActualClass("mouse", 15, 1.0 / 15)))); new Accuracy.PerClassResult("mouse", 47.0 / 75))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
} }
@ -92,13 +113,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat( assertThat(
accuracyResult.getActualClasses(), accuracyResult.getClasses(),
equalTo(Arrays.asList( equalTo(
new Accuracy.ActualClass("1", 15, 1.0 / 15), Arrays.asList(
new Accuracy.ActualClass("2", 15, 2.0 / 15), new Accuracy.PerClassResult("1", 57.0 / 75),
new Accuracy.ActualClass("3", 15, 3.0 / 15), new Accuracy.PerClassResult("2", 54.0 / 75),
new Accuracy.ActualClass("4", 15, 4.0 / 15), new Accuracy.PerClassResult("3", 51.0 / 75),
new Accuracy.ActualClass("5", 15, 5.0 / 15)))); new Accuracy.PerClassResult("4", 48.0 / 75),
new Accuracy.PerClassResult("5", 45.0 / 75))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
} }
@ -113,10 +135,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat( assertThat(
accuracyResult.getActualClasses(), accuracyResult.getClasses(),
equalTo(Arrays.asList( equalTo(
new Accuracy.ActualClass("true", 45, 27.0 / 45), Arrays.asList(
new Accuracy.ActualClass("false", 30, 18.0 / 30)))); new Accuracy.PerClassResult("false", 18.0 / 30),
new Accuracy.PerClassResult("true", 27.0 / 45))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
} }
@ -252,7 +275,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
EvaluateDataFrameAction.Response evaluateDataFrameResponse = EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame( evaluateDataFrame(
ANIMALS_DATA_INDEX, ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3)))); new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));

View File

@ -466,12 +466,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
{ // Accuracy { // Accuracy
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
List<Accuracy.ActualClass> actualClasses = accuracyResult.getActualClasses(); for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) {
assertThat( assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()), assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
equalTo(dependentVariableValuesAsStrings)); }
actualClasses.forEach(
actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
} }
{ // MulticlassConfusionMatrix { // MulticlassConfusionMatrix

View File

@ -620,16 +620,13 @@ setup:
- match: - match:
classification.accuracy: classification.accuracy:
actual_classes: classes:
- actual_class: "cat" - class_name: "cat"
actual_class_doc_count: 3 accuracy: 0.625 # 5 out of 8
accuracy: 0.6666666666666666 # 2 out of 3 - class_name: "dog"
- actual_class: "dog" accuracy: 0.75 # 6 out of 8
actual_class_doc_count: 3 - class_name: "mouse"
accuracy: 0.6666666666666666 # 2 out of 3 accuracy: 0.875 # 7 out of 8
- actual_class: "mouse"
actual_class_doc_count: 2
accuracy: 0.5 # 1 out of 2
overall_accuracy: 0.625 # 5 out of 8 overall_accuracy: 0.625 # 5 out of 8
--- ---
"Test classification precision": "Test classification precision":