diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java index 4db165be06c..151783499e4 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java @@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; @@ -35,10 +36,25 @@ import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; /** - * {@link AccuracyMetric} is a metric that answers the question: - * "What fraction of examples have been classified correctly by the classifier?" + * {@link AccuracyMetric} is a metric that answers the following two questions: * - * equation: accuracy = 1/n * Σ(y == y´) + * 1. What is the fraction of documents for which predicted class equals the actual class? + * + * equation: overall_accuracy = 1/n * Σ(y == y') + * where: n = total number of documents + * y = document's actual class + * y' = document's predicted class + * + * 2. For any given class X, what is the fraction of documents for which either + * a) both actual and predicted class are equal to X (true positives) + * or + * b) both actual and predicted class are not equal to X (true negatives) + * + * equation: accuracy(X) = 1/n * (TP(X) + TN(X)) + * where: X = class being examined + * n = total number of documents + * TP(X) = number of true positives wrt X + * TN(X) = number of true negatives wrt X */ public class AccuracyMetric implements EvaluationMetric { @@ -78,15 +94,15 @@ public class AccuracyMetric implements EvaluationMetric { public static class Result implements EvaluationMetric.Result { - private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); + private static final ParseField CLASSES = new ParseField("classes"); private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); } @@ -94,13 +110,13 @@ public class AccuracyMetric implements EvaluationMetric { return PARSER.apply(parser, null); } - /** List of actual classes. */ - private final List actualClasses; - /** Fraction of documents predicted correctly. */ + /** List of per-class results. */ + private final List classes; + /** Fraction of documents for which predicted class equals the actual class. */ private final double overallAccuracy; - public Result(List actualClasses, double overallAccuracy) { - this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses)); + public Result(List classes, double overallAccuracy) { + this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes)); this.overallAccuracy = overallAccuracy; } @@ -109,8 +125,8 @@ public class AccuracyMetric implements EvaluationMetric { return NAME; } - public List getActualClasses() { - return actualClasses; + public List getClasses() { + return classes; } public double getOverallAccuracy() { @@ -120,7 +136,7 @@ public class AccuracyMetric implements EvaluationMetric { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); + builder.field(CLASSES.getPreferredName(), classes); builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); builder.endObject(); return builder; @@ -131,52 +147,42 @@ public class AccuracyMetric implements EvaluationMetric { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; - return Objects.equals(this.actualClasses, that.actualClasses) + return Objects.equals(this.classes, that.classes) && this.overallAccuracy == that.overallAccuracy; } @Override public int hashCode() { - return Objects.hash(actualClasses, overallAccuracy); + return Objects.hash(classes, overallAccuracy); } } - public static class ActualClass implements ToXContentObject { + public static class PerClassResult implements ToXContentObject { - private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); - private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField CLASS_NAME = new ParseField("class_name"); private static final ParseField ACCURACY = new ParseField("accuracy"); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); static { - PARSER.declareString(constructorArg(), ACTUAL_CLASS); - PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareString(constructorArg(), CLASS_NAME); PARSER.declareDouble(constructorArg(), ACCURACY); } - /** Name of the actual class. */ - private final String actualClass; - /** Number of documents (examples) belonging to the {code actualClass} class. */ - private final long actualClassDocCount; - /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */ + /** Name of the class. */ + private final String className; + /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */ private final double accuracy; - public ActualClass( - String actualClass, long actualClassDocCount, double accuracy) { - this.actualClass = Objects.requireNonNull(actualClass); - this.actualClassDocCount = actualClassDocCount; + public PerClassResult(String className, double accuracy) { + this.className = Objects.requireNonNull(className); this.accuracy = accuracy; } - public String getActualClass() { - return actualClass; - } - - public long getActualClassDocCount() { - return actualClassDocCount; + public String getClassName() { + return className; } public double getAccuracy() { @@ -186,8 +192,7 @@ public class AccuracyMetric implements EvaluationMetric { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); - builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + builder.field(CLASS_NAME.getPreferredName(), className); builder.field(ACCURACY.getPreferredName(), accuracy); builder.endObject(); return builder; @@ -197,15 +202,19 @@ public class AccuracyMetric implements EvaluationMetric { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ActualClass that = (ActualClass) o; - return Objects.equals(this.actualClass, that.actualClass) - && this.actualClassDocCount == that.actualClassDocCount + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) && this.accuracy == that.accuracy; } @Override public int hashCode() { - return Objects.hash(actualClass, actualClassDocCount, accuracy); + return Objects.hash(className, accuracy); + } + + @Override + public String toString() { + return Strings.toString(this); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 18a02f2a460..bf1923c4324 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1849,15 +1849,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME); assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); assertThat( - accuracyResult.getActualClasses(), + accuracyResult.getClasses(), equalTo( Arrays.asList( - // 3 out of 5 examples labeled as "cat" were classified correctly - new AccuracyMetric.ActualClass("cat", 5, 0.6), - // 3 out of 4 examples labeled as "dog" were classified correctly - new AccuracyMetric.ActualClass("dog", 4, 0.75), - // no examples labeled as "ant" were classified correctly - new AccuracyMetric.ActualClass("ant", 1, 0.0)))); + // 9 out of 10 examples were classified correctly + new AccuracyMetric.PerClassResult("ant", 0.9), + // 6 out of 10 examples were classified correctly + new AccuracyMetric.PerClassResult("cat", 0.6), + // 8 out of 10 examples were classified correctly + new AccuracyMetric.PerClassResult("dog", 0.8)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly } { // Precision diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java index df48ef3123d..8758cea86c4 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java @@ -19,7 +19,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; -import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -41,13 +41,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase public static Result randomResult() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List actualClasses = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double accuracy = randomDoubleBetween(0.0, 1.0, true); - actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); + classes.add(new PerClassResult(classNames.get(i), accuracy)); } double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); - return new Result(actualClasses, overallAccuracy); + return new Result(classes, overallAccuracy); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java index 36bf7634cb4..8a106175ace 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java @@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable { * Gets the evaluation result for this metric. * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise */ - Optional getResult(); + Optional getResult(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java index 01f303caf84..c6636329a65 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; -import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; @@ -29,7 +29,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.text.MessageFormat; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Locale; @@ -40,22 +39,36 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName; /** - * {@link Accuracy} is a metric that answers the question: - * "What fraction of examples have been classified correctly by the classifier?" + * {@link Accuracy} is a metric that answers the following two questions: * - * equation: accuracy = 1/n * Σ(y == y´) + * 1. What is the fraction of documents for which predicted class equals the actual class? + * + * equation: overall_accuracy = 1/n * Σ(y == y') + * where: n = total number of documents + * y = document's actual class + * y' = document's predicted class + * + * 2. For any given class X, what is the fraction of documents for which either + * a) both actual and predicted class are equal to X (true positives) + * or + * b) both actual and predicted class are not equal to X (true negatives) + * + * equation: accuracy(X) = 1/n * (TP(X) + TN(X)) + * where: X = class being examined + * n = total number of documents + * TP(X) = number of true positives wrt X + * TN(X) = number of true negatives wrt X */ public class Accuracy implements EvaluationMetric { public static final ParseField NAME = new ParseField("accuracy"); - private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; - private static final String CLASSES_AGG_NAME = "classification_classes"; - private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy"; - private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy"; + static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy"; - private static String buildScript(Object...args) { - return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args); + private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; + + private static Script buildScript(Object...args) { + return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args)); } private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new); @@ -64,11 +77,20 @@ public class Accuracy implements EvaluationMetric { return PARSER.apply(parser, null); } - private EvaluationMetricResult result; + private static final int MAX_CLASSES_CARDINALITY = 1000; - public Accuracy() {} + private final MulticlassConfusionMatrix matrix; + private final SetOnce actualField = new SetOnce<>(); + private final SetOnce overallAccuracy = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); - public Accuracy(StreamInput in) throws IOException {} + public Accuracy() { + this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_"); + } + + public Accuracy(StreamInput in) throws IOException { + this.matrix = new MulticlassConfusionMatrix(in); + } @Override public String getWriteableName() { @@ -82,43 +104,79 @@ public class Accuracy implements EvaluationMetric { @Override public final Tuple, List> aggs(String actualField, String predictedField) { - if (result != null) { - return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); + // Store given {@code actualField} for the purpose of generating error message in {@code process}. + this.actualField.trySet(actualField); + List aggs = new ArrayList<>(); + List pipelineAggs = new ArrayList<>(); + if (overallAccuracy.get() == null) { + aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField))); } - Script accuracyScript = new Script(buildScript(actualField, predictedField)); - return Tuple.tuple( - Arrays.asList( - AggregationBuilders.terms(CLASSES_AGG_NAME) - .field(actualField) - .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)), - AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)), - Collections.emptyList()); + if (result.get() == null) { + Tuple, List> matrixAggs = matrix.aggs(actualField, predictedField); + aggs.addAll(matrixAggs.v1()); + pipelineAggs.addAll(matrixAggs.v2()); + } + return Tuple.tuple(aggs, pipelineAggs); } @Override public void process(Aggregations aggs) { - if (result != null) { - return; + if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { + NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME); + overallAccuracy.set(overallAccuracyAgg.value()); } - Terms classesAgg = aggs.get(CLASSES_AGG_NAME); - NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME); - List actualClasses = new ArrayList<>(classesAgg.getBuckets().size()); - for (Terms.Bucket bucket : classesAgg.getBuckets()) { - String actualClass = bucket.getKeyAsString(); - long actualClassDocCount = bucket.getDocCount(); - NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME); - actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value())); + matrix.process(aggs); + if (result.get() == null && matrix.getResult().isPresent()) { + if (matrix.getResult().get().getOtherActualClassCount() > 0) { + // This means there were more than {@code maxClassesCardinality} buckets. + // We cannot calculate per-class accuracy accurately, so we fail. + throw ExceptionsHelper.badRequestException( + "Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get()); + } + result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get())); } - result = new Result(actualClasses, overallAccuracyAgg.value()); } @Override - public Optional getResult() { - return Optional.ofNullable(result); + public Optional getResult() { + return Optional.ofNullable(result.get()); + } + + /** + * Computes the per-class accuracy results based on multiclass confusion matrix's result. + * Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension. + * This method is visible for testing only. + */ + static List computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) { + assert matrixResult.getOtherActualClassCount() == 0; + // Number of actual classes taken into account + int n = matrixResult.getConfusionMatrix().size(); + // Total number of documents taken into account + long totalDocCount = + matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum(); + List classes = new ArrayList<>(n); + for (int i = 0; i < n; ++i) { + String className = matrixResult.getConfusionMatrix().get(i).getActualClass(); + // Start with the assumption that all the docs were predicted correctly. + long correctDocCount = totalDocCount; + for (int j = 0; j < n; ++j) { + if (i != j) { + // Subtract errors (false negatives) + correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount(); + // Subtract errors (false positives) + correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount(); + } + } + // Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix + correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount(); + classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount)); + } + return classes; } @Override public void writeTo(StreamOutput out) throws IOException { + matrix.writeTo(out); } @Override @@ -132,25 +190,26 @@ public class Accuracy implements EvaluationMetric { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - return true; + Accuracy that = (Accuracy) o; + return Objects.equals(this.matrix, that.matrix); } @Override public int hashCode() { - return Objects.hashCode(NAME.getPreferredName()); + return Objects.hash(matrix); } public static class Result implements EvaluationMetricResult { - private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); + private static final ParseField CLASSES = new ParseField("classes"); private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); static { - PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); + PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES); PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); } @@ -158,18 +217,18 @@ public class Accuracy implements EvaluationMetric { return PARSER.apply(parser, null); } - /** List of actual classes. */ - private final List actualClasses; - /** Fraction of documents predicted correctly. */ + /** List of per-class results. */ + private final List classes; + /** Fraction of documents for which predicted class equals the actual class. */ private final double overallAccuracy; - public Result(List actualClasses, double overallAccuracy) { - this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES)); + public Result(List classes, double overallAccuracy) { + this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES)); this.overallAccuracy = overallAccuracy; } public Result(StreamInput in) throws IOException { - this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); + this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new)); this.overallAccuracy = in.readDouble(); } @@ -183,8 +242,8 @@ public class Accuracy implements EvaluationMetric { return NAME.getPreferredName(); } - public List getActualClasses() { - return actualClasses; + public List getClasses() { + return classes; } public double getOverallAccuracy() { @@ -193,14 +252,14 @@ public class Accuracy implements EvaluationMetric { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeList(actualClasses); + out.writeList(classes); out.writeDouble(overallAccuracy); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); + builder.field(CLASSES.getPreferredName(), classes); builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); builder.endObject(); return builder; @@ -211,54 +270,47 @@ public class Accuracy implements EvaluationMetric { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; - return Objects.equals(this.actualClasses, that.actualClasses) + return Objects.equals(this.classes, that.classes) && this.overallAccuracy == that.overallAccuracy; } @Override public int hashCode() { - return Objects.hash(actualClasses, overallAccuracy); + return Objects.hash(classes, overallAccuracy); } } - public static class ActualClass implements ToXContentObject, Writeable { + public static class PerClassResult implements ToXContentObject, Writeable { - private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); - private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField CLASS_NAME = new ParseField("class_name"); private static final ParseField ACCURACY = new ParseField("accuracy"); @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1])); static { - PARSER.declareString(constructorArg(), ACTUAL_CLASS); - PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareString(constructorArg(), CLASS_NAME); PARSER.declareDouble(constructorArg(), ACCURACY); } - /** Name of the actual class. */ - private final String actualClass; - /** Number of documents (examples) belonging to the {code actualClass} class. */ - private final long actualClassDocCount; - /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */ + /** Name of the class. */ + private final String className; + /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */ private final double accuracy; - public ActualClass( - String actualClass, long actualClassDocCount, double accuracy) { - this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS); - this.actualClassDocCount = actualClassDocCount; + public PerClassResult(String className, double accuracy) { + this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME); this.accuracy = accuracy; } - public ActualClass(StreamInput in) throws IOException { - this.actualClass = in.readString(); - this.actualClassDocCount = in.readVLong(); + public PerClassResult(StreamInput in) throws IOException { + this.className = in.readString(); this.accuracy = in.readDouble(); } - public String getActualClass() { - return actualClass; + public String getClassName() { + return className; } public double getAccuracy() { @@ -267,16 +319,14 @@ public class Accuracy implements EvaluationMetric { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(actualClass); - out.writeVLong(actualClassDocCount); + out.writeString(className); out.writeDouble(accuracy); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); - builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + builder.field(CLASS_NAME.getPreferredName(), className); builder.field(ACCURACY.getPreferredName(), accuracy); builder.endObject(); return builder; @@ -286,15 +336,14 @@ public class Accuracy implements EvaluationMetric { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ActualClass that = (ActualClass) o; - return Objects.equals(this.actualClass, that.actualClass) - && this.actualClassDocCount == that.actualClassDocCount + PerClassResult that = (PerClassResult) o; + return Objects.equals(this.className, that.className) && this.accuracy == that.accuracy; } @Override public int hashCode() { - return Objects.hash(actualClass, actualClassDocCount, accuracy); + return Objects.hash(className, accuracy); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index 4f049efead3..8376382e41b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; @@ -53,13 +55,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { public static final ParseField NAME = new ParseField("multiclass_confusion_matrix"); public static final ParseField SIZE = new ParseField("size"); + public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix"); private static final ConstructingObjectParser PARSER = createParser(); private static ConstructingObjectParser createParser() { ConstructingObjectParser parser = - new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0])); + new ConstructingObjectParser<>( + NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1])); parser.declareInt(optionalConstructorArg(), SIZE); + parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX); return parser; } @@ -67,31 +72,39 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { return PARSER.apply(parser, null); } - private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; - private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; - private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; - private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; + static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; + static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; + static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; + static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; private static final String OTHER_BUCKET_KEY = "_other_"; + private static final String DEFAULT_AGG_NAME_PREFIX = ""; private static final int DEFAULT_SIZE = 10; private static final int MAX_SIZE = 1000; private final int size; - private List topActualClassNames; - private Result result; + private final String aggNamePrefix; + private final SetOnce> topActualClassNames = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public MulticlassConfusionMatrix() { - this((Integer) null); + this(null, null); } - public MulticlassConfusionMatrix(@Nullable Integer size) { + public MulticlassConfusionMatrix(@Nullable Integer size, @Nullable String aggNamePrefix) { if (size != null && (size <= 0 || size > MAX_SIZE)) { throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE); } this.size = size != null ? size : DEFAULT_SIZE; + this.aggNamePrefix = aggNamePrefix != null ? aggNamePrefix : DEFAULT_AGG_NAME_PREFIX; } public MulticlassConfusionMatrix(StreamInput in) throws IOException { this.size = in.readVInt(); + if (in.getVersion().onOrAfter(Version.V_7_6_0)) { + this.aggNamePrefix = in.readString(); + } else { + this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX; + } } @Override @@ -110,30 +123,30 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { @Override public final Tuple, List> aggs(String actualField, String predictedField) { - if (topActualClassNames == null) { // This is step 1 + if (topActualClassNames.get() == null) { // This is step 1 return Tuple.tuple( Arrays.asList( - AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) + AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) .field(actualField) .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) .size(size)), Collections.emptyList()); } - if (result == null) { // This is step 2 + if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersActual = - topActualClassNames.stream() + topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) .toArray(KeyedFilter[]::new); KeyedFilter[] keyedFiltersPredicted = - topActualClassNames.stream() + topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); return Tuple.tuple( Arrays.asList( - AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) + AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)) .field(actualField), - AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual) - .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted) + AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual) + .subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted) .otherBucket(true) .otherBucketKey(OTHER_BUCKET_KEY))), Collections.emptyList()); @@ -143,18 +156,18 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { @Override public void process(Aggregations aggs) { - if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { - Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); - topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); + if (topActualClassNames.get() == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) { + Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)); + topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList())); } - if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { - Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); - Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); + if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) { + Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)); + Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)); List actualClasses = new ArrayList<>(filtersAgg.getBuckets().size()); for (Filters.Bucket bucket : filtersAgg.getBuckets()) { String actualClass = bucket.getKeyAsString(); long actualClassDocCount = bucket.getDocCount(); - Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); + Filters subAgg = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS)); List predictedClasses = new ArrayList<>(); long otherPredictedClassDocCount = 0; for (Filters.Bucket subBucket : subAgg.getBuckets()) { @@ -169,18 +182,25 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount)); } - result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); + result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0))); } } + private String aggName(String aggNameWithoutPrefix) { + return aggNamePrefix + aggNameWithoutPrefix; + } + @Override - public Optional getResult() { - return Optional.ofNullable(result); + public Optional getResult() { + return Optional.ofNullable(result.get()); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeVInt(size); + if (out.getVersion().onOrAfter(Version.V_7_6_0)) { + out.writeString(aggNamePrefix); + } } @Override @@ -196,12 +216,13 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o; - return Objects.equals(this.size, that.size); + return this.size == that.size + && Objects.equals(this.aggNamePrefix, that.aggNamePrefix); } @Override public int hashCode() { - return Objects.hash(size); + return Objects.hash(size, aggNamePrefix); } public static class Result implements EvaluationMetricResult { @@ -335,6 +356,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { return actualClass; } + public long getActualClassDocCount() { + return actualClassDocCount; + } + public List getPredictedClasses() { return predictedClasses; } @@ -411,6 +436,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric { return predictedClass; } + public long getCount() { + return count; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(predictedClass); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java index 6ef2aeb1a86..30906efd41f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -77,9 +78,9 @@ public class Precision implements EvaluationMetric { private static final int MAX_CLASSES_CARDINALITY = 1000; - private String actualField; - private List topActualClassNames; - private EvaluationMetricResult result; + private final SetOnce actualField = new SetOnce<>(); + private final SetOnce> topActualClassNames = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public Precision() {} @@ -98,8 +99,8 @@ public class Precision implements EvaluationMetric { @Override public final Tuple, List> aggs(String actualField, String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. - this.actualField = actualField; - if (topActualClassNames == null) { // This is step 1 + this.actualField.trySet(actualField); + if (topActualClassNames.get() == null) { // This is step 1 return Tuple.tuple( Arrays.asList( AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME) @@ -108,9 +109,9 @@ public class Precision implements EvaluationMetric { .size(MAX_CLASSES_CARDINALITY)), Collections.emptyList()); } - if (result == null) { // This is step 2 + if (result.get() == null) { // This is step 2 KeyedFilter[] keyedFiltersPredicted = - topActualClassNames.stream() + topActualClassNames.get().stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); Script script = buildScript(actualField, predictedField); @@ -128,18 +129,18 @@ public class Precision implements EvaluationMetric { @Override public void process(Aggregations aggs) { - if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { + if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) { Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME); if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) { // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // We cannot calculate average precision accurately, so we fail. throw ExceptionsHelper.badRequestException( - "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField); + "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get()); } - topActualClassNames = - topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); + topActualClassNames.set( + topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList())); } - if (result == null && + if (result.get() == null && aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters && aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME); @@ -153,13 +154,13 @@ public class Precision implements EvaluationMetric { classes.add(new PerClassResult(className, precision)); } } - result = new Result(classes, avgPrecisionAgg.value()); + result.set(new Result(classes, avgPrecisionAgg.value())); } } @Override - public Optional getResult() { - return Optional.ofNullable(result); + public Optional getResult() { + return Optional.ofNullable(result.get()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java index 27ad7a8b3bc..c4f2e8e60ab 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.apache.lucene.util.SetOnce; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; @@ -71,8 +72,8 @@ public class Recall implements EvaluationMetric { private static final int MAX_CLASSES_CARDINALITY = 1000; - private String actualField; - private EvaluationMetricResult result; + private final SetOnce actualField = new SetOnce<>(); + private final SetOnce result = new SetOnce<>(); public Recall() {} @@ -91,8 +92,8 @@ public class Recall implements EvaluationMetric { @Override public final Tuple, List> aggs(String actualField, String predictedField) { // Store given {@code actualField} for the purpose of generating error message in {@code process}. - this.actualField = actualField; - if (result != null) { + this.actualField.trySet(actualField); + if (result.get() != null) { return Tuple.tuple(Collections.emptyList(), Collections.emptyList()); } Script script = buildScript(actualField, predictedField); @@ -110,7 +111,7 @@ public class Recall implements EvaluationMetric { @Override public void process(Aggregations aggs) { - if (result == null && + if (result.get() == null && aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms && aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) { Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME); @@ -118,7 +119,7 @@ public class Recall implements EvaluationMetric { // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets. // We cannot calculate average recall accurately, so we fail. throw ExceptionsHelper.badRequestException( - "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField); + "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get()); } NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME); List classes = new ArrayList<>(byActualClassAgg.getBuckets().size()); @@ -127,13 +128,13 @@ public class Recall implements EvaluationMetric { NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME); classes.add(new PerClassResult(className, recallAgg.value())); } - result = new Result(classes, avgRecallAgg.value()); + result.set(new Result(classes, avgRecallAgg.value())); } } @Override - public Optional getResult() { - return Optional.ofNullable(result); + public Optional getResult() { + return Optional.ofNullable(result.get()); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java index 8fb4c6c0240..176aa6e9a30 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import java.util.ArrayList; import java.util.List; @@ -22,13 +22,13 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase public static Result createRandom() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - List actualClasses = new ArrayList<>(numClasses); + List classes = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { double accuracy = randomDoubleBetween(0.0, 1.0, true); - actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); + classes.add(new PerClassResult(classNames.get(i), accuracy)); } double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); - return new Result(actualClasses, overallAccuracy); + return new Result(classes, overallAccuracy); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java index c5e36564c57..6bf72ba73fa 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -5,17 +5,27 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; import java.io.IOException; import java.util.Arrays; import java.util.Collections; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket; +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; public class AccuracyTests extends AbstractSerializingTestCase { @@ -46,52 +56,114 @@ public class AccuracyTests extends AbstractSerializingTestCase { public void testProcess() { Aggregations aggs = new Aggregations(Arrays.asList( - mockTerms("classification_classes"), - mockSingleValue("classification_overall_accuracy", 0.8123), - mockSingleValue("some_other_single_metric_agg", 0.2377) - )); + mockTerms( + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, + Arrays.asList( + mockTermsBucket("dog", new Aggregations(Collections.emptyList())), + mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), + 100L), + mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + Arrays.asList( + mockFiltersBucket( + "dog", + 30, + new Aggregations(Arrays.asList(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "cat", + 70, + new Aggregations(Arrays.asList(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L), + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); Accuracy accuracy = new Accuracy(); accuracy.process(aggs); - assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123))); + assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty())); + + Result result = accuracy.getResult().get(); + assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat( + result.getClasses(), + equalTo( + Arrays.asList( + new PerClassResult("dog", 0.5), + new PerClassResult("cat", 0.5)))); + assertThat(result.getOverallAccuracy(), equalTo(0.5)); } - public void testProcess_GivenMissingAgg() { - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockTerms("classification_classes"), - mockSingleValue("some_other_single_metric_agg", 0.2377) - )); - Accuracy accuracy = new Accuracy(); - expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); - } - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockSingleValue("classification_overall_accuracy", 0.8123), - mockSingleValue("some_other_single_metric_agg", 0.2377) - )); - Accuracy accuracy = new Accuracy(); - expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); - } + public void testProcess_GivenCardinalityTooHigh() { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms( + "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, + Arrays.asList( + mockTermsBucket("dog", new Aggregations(Collections.emptyList())), + mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), + 100L), + mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, + Arrays.asList( + mockFiltersBucket( + "dog", + 30, + new Aggregations(Arrays.asList(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockFiltersBucket( + "cat", + 70, + new Aggregations(Arrays.asList(mockFilters( + "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, + Arrays.asList( + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), + mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L), + mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5))); + + Accuracy accuracy = new Accuracy(); + accuracy.aggs("foo", "bar"); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs)); + assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high")); } - public void testProcess_GivenAggOfWrongType() { - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockTerms("classification_classes"), - mockTerms("classification_overall_accuracy") - )); - Accuracy accuracy = new Accuracy(); - expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); - } - { - Aggregations aggs = new Aggregations(Arrays.asList( - mockSingleValue("classification_classes", 1.0), - mockSingleValue("classification_overall_accuracy", 0.8123) - )); - Accuracy accuracy = new Accuracy(); - expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); - } + public void testComputePerClassAccuracy() { + assertThat( + Accuracy.computePerClassAccuracy( + new MulticlassConfusionMatrix.Result( + Arrays.asList( + new MulticlassConfusionMatrix.ActualClass("A", 14, Arrays.asList( + new MulticlassConfusionMatrix.PredictedClass("A", 1), + new MulticlassConfusionMatrix.PredictedClass("B", 6), + new MulticlassConfusionMatrix.PredictedClass("C", 4) + ), 3L), + new MulticlassConfusionMatrix.ActualClass("B", 20, Arrays.asList( + new MulticlassConfusionMatrix.PredictedClass("A", 5), + new MulticlassConfusionMatrix.PredictedClass("B", 3), + new MulticlassConfusionMatrix.PredictedClass("C", 9) + ), 3L), + new MulticlassConfusionMatrix.ActualClass("C", 17, Arrays.asList( + new MulticlassConfusionMatrix.PredictedClass("A", 8), + new MulticlassConfusionMatrix.PredictedClass("B", 2), + new MulticlassConfusionMatrix.PredictedClass("C", 7) + ), 0L)), + 0)), + equalTo( + Arrays.asList( + new Accuracy.PerClassResult("A", 25.0 / 51), // 13 false positives, 13 false negatives + new Accuracy.PerClassResult("B", 26.0 / 51), // 8 false positives, 17 false negatives + new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives + ); + } + + public void testComputePerClassAccuracy_OtherActualClassCountIsNonZero() { + expectThrows( + AssertionError.class, + () -> Accuracy.computePerClassAccuracy(new MulticlassConfusionMatrix.Result(Collections.emptyList(), 1))); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index da0778db140..fce65077996 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result; import java.io.IOException; import java.util.Arrays; @@ -56,20 +57,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< public static MulticlassConfusionMatrix createRandom() { Integer size = randomBoolean() ? null : randomIntBetween(1, 1000); - return new MulticlassConfusionMatrix(size); + return new MulticlassConfusionMatrix(size, null); } public void testConstructor_SizeValidationFailures() { { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1)); + ElasticsearchStatusException e = + expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1, null)); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); } { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0)); + ElasticsearchStatusException e = + expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0, null)); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); } { - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001)); + ElasticsearchStatusException e = + expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001, null)); assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); } } @@ -84,36 +88,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< public void testEvaluate() { Aggregations aggs = new Aggregations(Arrays.asList( mockTerms( - "multiclass_confusion_matrix_step_1_by_actual_class", + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 0L), mockFilters( - "multiclass_confusion_matrix_step_2_by_actual_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( mockFiltersBucket( "dog", 30, new Aggregations(Arrays.asList(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", 70, new Aggregations(Arrays.asList(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), - mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); + mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L))); - MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); - assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Result result = confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( result.getConfusionMatrix(), equalTo( @@ -126,36 +130,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< public void testEvaluate_OtherClassesCountGreaterThanZero() { Aggregations aggs = new Aggregations(Arrays.asList( mockTerms( - "multiclass_confusion_matrix_step_1_by_actual_class", + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 100L), mockFilters( - "multiclass_confusion_matrix_step_2_by_actual_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS, Arrays.asList( mockFiltersBucket( "dog", 30, new Aggregations(Arrays.asList(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket( "cat", 85, new Aggregations(Arrays.asList(mockFilters( - "multiclass_confusion_matrix_step_2_by_predicted_class", + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS, Arrays.asList( mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), - mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); + mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L))); - MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null); confusionMatrix.process(aggs); assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty())); - MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); - assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Result result = confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat( result.getConfusionMatrix(), equalTo( diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 8c5987675f1..c2758a2b653 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall; @@ -22,6 +23,8 @@ import org.junit.Before; import java.util.Arrays; import java.util.List; +import static java.util.stream.Collectors.toList; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -53,10 +56,28 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); assertThat( - evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), - equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + contains(MulticlassConfusionMatrix.NAME.getPreferredName())); + } + + public void testEvaluate_AllMetrics() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, + new Classification( + ANIMAL_NAME_FIELD, + ANIMAL_NAME_PREDICTION_FIELD, + Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat( + evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()), + contains( + Accuracy.NAME.getPreferredName(), + MulticlassConfusionMatrix.NAME.getPreferredName(), + Precision.NAME.getPreferredName(), + Recall.NAME.getPreferredName())); } public void testEvaluate_Accuracy_KeywordField() { @@ -70,14 +91,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat( - accuracyResult.getActualClasses(), + accuracyResult.getClasses(), equalTo( Arrays.asList( - new Accuracy.ActualClass("ant", 15, 1.0 / 15), - new Accuracy.ActualClass("cat", 15, 1.0 / 15), - new Accuracy.ActualClass("dog", 15, 1.0 / 15), - new Accuracy.ActualClass("fox", 15, 1.0 / 15), - new Accuracy.ActualClass("mouse", 15, 1.0 / 15)))); + new Accuracy.PerClassResult("ant", 47.0 / 75), + new Accuracy.PerClassResult("cat", 47.0 / 75), + new Accuracy.PerClassResult("dog", 47.0 / 75), + new Accuracy.PerClassResult("fox", 47.0 / 75), + new Accuracy.PerClassResult("mouse", 47.0 / 75)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); } @@ -92,13 +113,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat( - accuracyResult.getActualClasses(), - equalTo(Arrays.asList( - new Accuracy.ActualClass("1", 15, 1.0 / 15), - new Accuracy.ActualClass("2", 15, 2.0 / 15), - new Accuracy.ActualClass("3", 15, 3.0 / 15), - new Accuracy.ActualClass("4", 15, 4.0 / 15), - new Accuracy.ActualClass("5", 15, 5.0 / 15)))); + accuracyResult.getClasses(), + equalTo( + Arrays.asList( + new Accuracy.PerClassResult("1", 57.0 / 75), + new Accuracy.PerClassResult("2", 54.0 / 75), + new Accuracy.PerClassResult("3", 51.0 / 75), + new Accuracy.PerClassResult("4", 48.0 / 75), + new Accuracy.PerClassResult("5", 45.0 / 75)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75)); } @@ -113,10 +135,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); assertThat( - accuracyResult.getActualClasses(), - equalTo(Arrays.asList( - new Accuracy.ActualClass("true", 45, 27.0 / 45), - new Accuracy.ActualClass("false", 30, 18.0 / 30)))); + accuracyResult.getClasses(), + equalTo( + Arrays.asList( + new Accuracy.PerClassResult("false", 18.0 / 30), + new Accuracy.PerClassResult("true", 27.0 / 45)))); assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75)); } @@ -252,7 +275,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, - new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3)))); + new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index f5f7b3d326f..632a4ee794e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -466,12 +466,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { { // Accuracy Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); - List actualClasses = accuracyResult.getActualClasses(); - assertThat( - actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()), - equalTo(dependentVariableValuesAsStrings)); - actualClasses.forEach( - actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)))); + for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) { + assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings))); + assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); + } } { // MulticlassConfusionMatrix diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 95a7ef4e332..2b16b79ac84 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -620,16 +620,13 @@ setup: - match: classification.accuracy: - actual_classes: - - actual_class: "cat" - actual_class_doc_count: 3 - accuracy: 0.6666666666666666 # 2 out of 3 - - actual_class: "dog" - actual_class_doc_count: 3 - accuracy: 0.6666666666666666 # 2 out of 3 - - actual_class: "mouse" - actual_class_doc_count: 2 - accuracy: 0.5 # 1 out of 2 + classes: + - class_name: "cat" + accuracy: 0.625 # 5 out of 8 + - class_name: "dog" + accuracy: 0.75 # 6 out of 8 + - class_name: "mouse" + accuracy: 0.875 # 7 out of 8 overall_accuracy: 0.625 # 5 out of 8 --- "Test classification precision":