From c7ac2011eb50a35988a81cfef826f7c9a861d836 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Thu, 21 Nov 2019 15:01:18 +0100 Subject: [PATCH] [7.x] Implement accuracy metric for multiclass classification (#47772) (#49430) --- .../MlEvaluationNamedXContentProvider.java | 5 + .../classification/AccuracyMetric.java | 211 +++++++++++++ .../client/MachineLearningIT.java | 22 ++ .../client/RestHighLevelClientTests.java | 9 +- .../MlClientDocumentationIT.java | 18 +- .../AccuracyMetricResultTests.java | 63 ++++ .../classification/AccuracyMetricTests.java | 53 ++++ .../classification/ClassificationTests.java | 11 +- .../MulticlassConfusionMatrixMetricTests.java | 8 +- .../ml/evaluate-data-frame.asciidoc | 11 +- .../xpack/core/XPackClientPlugin.java | 3 + .../MlEvaluationNamedXContentProvider.java | 6 + .../evaluation/classification/Accuracy.java | 294 ++++++++++++++++++ .../classification/AccuracyResultTests.java | 44 +++ .../classification/AccuracyTests.java | 112 +++++++ .../classification/ClassificationTests.java | 7 +- .../ClassificationEvaluationIT.java | 86 ++--- .../test/ml/evaluate_data_frame.yml | 29 ++ 18 files changed, 914 insertions(+), 78 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index dca644b663e..efe58b9739e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; @@ -51,6 +52,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(MulticlassConfusionMatrixMetric.NAME), @@ -68,6 +71,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(MulticlassConfusionMatrixMetric.NAME), diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java new file mode 100644 index 00000000000..4db165be06c --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java @@ -0,0 +1,211 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * {@link AccuracyMetric} is a metric that answers the question: + * "What fraction of examples have been classified correctly by the classifier?" + * + * equation: accuracy = 1/n * Σ(y == y´) + */ +public class AccuracyMetric implements EvaluationMetric { + + public static final String NAME = "accuracy"; + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, AccuracyMetric::new); + + public static AccuracyMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public AccuracyMetric() {} + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hashCode(NAME); + } + + public static class Result implements EvaluationMetric.Result { + + private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); + private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + + static { + PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); + PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** List of actual classes. */ + private final List actualClasses; + /** Fraction of documents predicted correctly. */ + private final double overallAccuracy; + + public Result(List actualClasses, double overallAccuracy) { + this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses)); + this.overallAccuracy = overallAccuracy; + } + + @Override + public String getMetricName() { + return NAME; + } + + public List getActualClasses() { + return actualClasses; + } + + public double getOverallAccuracy() { + return overallAccuracy; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); + builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.actualClasses, that.actualClasses) + && this.overallAccuracy == that.overallAccuracy; + } + + @Override + public int hashCode() { + return Objects.hash(actualClasses, overallAccuracy); + } + } + + public static class ActualClass implements ToXContentObject { + + private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField ACCURACY = new ParseField("accuracy"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); + + static { + PARSER.declareString(constructorArg(), ACTUAL_CLASS); + PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareDouble(constructorArg(), ACCURACY); + } + + /** Name of the actual class. */ + private final String actualClass; + /** Number of documents (examples) belonging to the {code actualClass} class. */ + private final long actualClassDocCount; + /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */ + private final double accuracy; + + public ActualClass( + String actualClass, long actualClassDocCount, double accuracy) { + this.actualClass = Objects.requireNonNull(actualClass); + this.actualClassDocCount = actualClassDocCount; + this.accuracy = accuracy; + } + + public String getActualClass() { + return actualClass; + } + + public long getActualClassDocCount() { + return actualClassDocCount; + } + + public double getAccuracy() { + return accuracy; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + builder.field(ACCURACY.getPreferredName(), accuracy); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ActualClass that = (ActualClass) o; + return Objects.equals(this.actualClass, that.actualClass) + && this.actualClassDocCount == that.actualClassDocCount + && this.accuracy == that.accuracy; + } + + @Override + public int hashCode() { + return Objects.hash(actualClass, actualClassDocCount, accuracy); + } + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 841fc42accd..910d091c8a0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -125,6 +125,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; @@ -1813,6 +1814,27 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + { // Accuracy + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric())); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME); + assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); + assertThat( + accuracyResult.getActualClasses(), + equalTo( + Arrays.asList( + new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly + new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly + new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly + assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly + } { // No size provided for MulticlassConfusionMatrixMetric, default used instead EvaluateDataFrameRequest evaluateDataFrameRequest = new EvaluateDataFrameRequest( diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index eab3d6882e0..d2e55388602 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -57,6 +57,7 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction; import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; @@ -687,7 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(49, namedXContents.size()); + assertEquals(51, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -729,21 +730,23 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME)); - assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, + AccuracyMetric.NAME, MulticlassConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME, RSquaredMetric.NAME)); - assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, + AccuracyMetric.NAME, MulticlassConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME, RSquaredMetric.NAME)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 762eaaaf906..63a397eb0c0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -141,6 +141,7 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; @@ -3347,20 +3348,27 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { "actual_class", // <2> "predicted_class", // <3> // Evaluation metrics // <4> - new MulticlassConfusionMatrixMetric(3)); // <5> + new AccuracyMetric(), // <5> + new MulticlassConfusionMatrixMetric(3)); // <6> // end::evaluate-data-frame-evaluation-classification EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT); // tag::evaluate-data-frame-results-classification - MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = - response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1> + AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1> + double accuracy = accuracyResult.getOverallAccuracy(); // <2> - List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2> - long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3> + MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = + response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3> + + List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4> + long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5> // end::evaluate-data-frame-results-classification + assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); + assertThat(accuracy, equalTo(0.6)); + assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); assertThat( confusionMatrix, diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java new file mode 100644 index 00000000000..4e6557b4f58 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class AccuracyMetricResultTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected AccuracyMetric.Result createTestInstance() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + List actualClasses = new ArrayList<>(numClasses); + for (int i = 0; i < numClasses; i++) { + double accuracy = randomDoubleBetween(0.0, 1.0, true); + actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); + } + double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); + return new Result(actualClasses, overallAccuracy); + } + + @Override + protected AccuracyMetric.Result doParseInstance(XContentParser parser) throws IOException { + return AccuracyMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricTests.java new file mode 100644 index 00000000000..06377ed0f4a --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class AccuracyMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + static AccuracyMetric createRandom() { + return new AccuracyMetric(); + } + + @Override + protected AccuracyMetric createTestInstance() { + return createRandom(); + } + + @Override + protected AccuracyMetric doParseInstance(XContentParser parser) throws IOException { + return AccuracyMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java index a72b483518c..491c74fc2e0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation.classification; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -25,6 +26,7 @@ import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; import java.util.Arrays; +import java.util.List; import java.util.function.Predicate; public class ClassificationTests extends AbstractXContentTestCase { @@ -34,11 +36,10 @@ public class ClassificationTests extends AbstractXContentTestCase metrics = + randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom())); + return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java index f4de12796f0..29835cea9e4 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java @@ -32,12 +32,16 @@ public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCa return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); } - @Override - protected MulticlassConfusionMatrixMetric createTestInstance() { + static MulticlassConfusionMatrixMetric createRandom() { Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null; return new MulticlassConfusionMatrixMetric(size); } + @Override + protected MulticlassConfusionMatrixMetric createTestInstance() { + return createRandom(); + } + @Override protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException { return MulticlassConfusionMatrixMetric.fromXContent(parser); diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc index 61f18dbd092..b4abafa249e 100644 --- a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -52,7 +52,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification] <2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to. <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example. <4> The remaining parameters are the metrics to be calculated based on the two fields described above -<5> Multiclass confusion matrix of size 3 +<5> Accuracy +<6> Multiclass confusion matrix of size 3 ===== Regression @@ -101,9 +102,11 @@ include-tagged::{doc-tests-file}[{api}-results-softclassification] include-tagged::{doc-tests-file}[{api}-results-classification] -------------------------------------------------- -<1> Fetching multiclass confusion matrix metric by name -<2> Fetching the contents of the confusion matrix -<3> Fetching the number of classes that were not included in the matrix +<1> Fetching accuracy metric by name +<2> Fetching the actual accuracy value +<3> Fetching multiclass confusion matrix metric by name +<4> Fetching the contents of the confusion matrix +<5> Fetching the number of classes that were not included in the matrix ===== Regression diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 5f34bc44631..80123670a64 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -148,6 +148,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; @@ -516,6 +517,8 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl MulticlassConfusionMatrix::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(), MulticlassConfusionMatrix.Result::new), + new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, Accuracy.NAME.getPreferredName(), Accuracy.Result::new), new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new), new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 8036c5ab895..1ef8b89a996 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; @@ -48,6 +49,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Classification metrics namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME, MulticlassConfusionMatrix::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent)); // Regression metrics namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent)); @@ -78,6 +80,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME.getPreferredName(), MulticlassConfusionMatrix::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError::new)); @@ -95,6 +98,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(), MulticlassConfusionMatrix.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + Accuracy.NAME.getPreferredName(), + Accuracy.Result::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError.Result::new)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java new file mode 100644 index 00000000000..6acd5de4f45 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java @@ -0,0 +1,294 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * {@link Accuracy} is a metric that answers the question: + * "What fraction of examples have been classified correctly by the classifier?" + * + * equation: accuracy = 1/n * Σ(y == y´) + */ +public class Accuracy implements ClassificationMetric { + + public static final ParseField NAME = new ParseField("accuracy"); + + private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value"; + private static final String CLASSES_AGG_NAME = "classification_classes"; + private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy"; + private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy"; + + private static String buildScript(Object...args) { + return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args); + } + + private static final ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new); + + public static Accuracy fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private EvaluationMetricResult result; + + public Accuracy() {} + + public Accuracy(StreamInput in) throws IOException {} + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public final List aggs(String actualField, String predictedField) { + if (result != null) { + return Collections.emptyList(); + } + Script accuracyScript = new Script(buildScript(actualField, predictedField)); + return Arrays.asList( + AggregationBuilders.terms(CLASSES_AGG_NAME) + .field(actualField) + .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)), + AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)); + } + + @Override + public void process(Aggregations aggs) { + if (result != null) { + return; + } + Terms classesAgg = aggs.get(CLASSES_AGG_NAME); + NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME); + List actualClasses = new ArrayList<>(classesAgg.getBuckets().size()); + for (Terms.Bucket bucket : classesAgg.getBuckets()) { + String actualClass = bucket.getKeyAsString(); + long actualClassDocCount = bucket.getDocCount(); + NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME); + actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value())); + } + result = new Result(actualClasses, overallAccuracyAgg.value()); + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hashCode(NAME.getPreferredName()); + } + + public static class Result implements EvaluationMetricResult { + + private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes"); + private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List) a[0], (double) a[1])); + + static { + PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES); + PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** List of actual classes. */ + private final List actualClasses; + /** Fraction of documents predicted correctly. */ + private final double overallAccuracy; + + public Result(List actualClasses, double overallAccuracy) { + this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES)); + this.overallAccuracy = overallAccuracy; + } + + public Result(StreamInput in) throws IOException { + this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); + this.overallAccuracy = in.readDouble(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + public List getActualClasses() { + return actualClasses; + } + + public double getOverallAccuracy() { + return overallAccuracy; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(actualClasses); + out.writeDouble(overallAccuracy); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses); + builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.actualClasses, that.actualClasses) + && this.overallAccuracy == that.overallAccuracy; + } + + @Override + public int hashCode() { + return Objects.hash(actualClasses, overallAccuracy); + } + } + + public static class ActualClass implements ToXContentObject, Writeable { + + private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField ACCURACY = new ParseField("accuracy"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2])); + + static { + PARSER.declareString(constructorArg(), ACTUAL_CLASS); + PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareDouble(constructorArg(), ACCURACY); + } + + /** Name of the actual class. */ + private final String actualClass; + /** Number of documents (examples) belonging to the {code actualClass} class. */ + private final long actualClassDocCount; + /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */ + private final double accuracy; + + public ActualClass( + String actualClass, long actualClassDocCount, double accuracy) { + this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS); + this.actualClassDocCount = actualClassDocCount; + this.accuracy = accuracy; + } + + public ActualClass(StreamInput in) throws IOException { + this.actualClass = in.readString(); + this.actualClassDocCount = in.readVLong(); + this.accuracy = in.readDouble(); + } + + public String getActualClass() { + return actualClass; + } + + public double getAccuracy() { + return accuracy; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualClass); + out.writeVLong(actualClassDocCount); + out.writeDouble(accuracy); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + builder.field(ACCURACY.getPreferredName(), accuracy); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ActualClass that = (ActualClass) o; + return Objects.equals(this.actualClass, that.actualClass) + && this.actualClassDocCount == that.actualClassDocCount + && this.accuracy == that.accuracy; + } + + @Override + public int hashCode() { + return Objects.hash(actualClass, actualClassDocCount, accuracy); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java new file mode 100644 index 00000000000..bb3cc991920 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class AccuracyResultTests extends AbstractWireSerializingTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected Accuracy.Result createTestInstance() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + List actualClasses = new ArrayList<>(numClasses); + for (int i = 0; i < numClasses; i++) { + double accuracy = randomDoubleBetween(0.0, 1.0, true); + actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy)); + } + double overallAccuracy = randomDoubleBetween(0.0, 1.0, true); + return new Result(actualClasses, overallAccuracy); + } + + @Override + protected Writeable.Reader instanceReader() { + return Accuracy.Result::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java new file mode 100644 index 00000000000..1c4caa0c51d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AccuracyTests extends AbstractSerializingTestCase { + + @Override + protected Accuracy doParseInstance(XContentParser parser) throws IOException { + return Accuracy.fromXContent(parser); + } + + @Override + protected Accuracy createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Accuracy::new; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static Accuracy createRandom() { + return new Accuracy(); + } + + public void testProcess() { + Aggregations aggs = new Aggregations(Arrays.asList( + createTermsAgg("classification_classes"), + createSingleMetricAgg("classification_overall_accuracy", 0.8123), + createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + )); + + Accuracy accuracy = new Accuracy(); + accuracy.process(aggs); + + assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123))); + } + + public void testProcess_GivenMissingAgg() { + { + Aggregations aggs = new Aggregations(Arrays.asList( + createTermsAgg("classification_classes"), + createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + )); + Accuracy accuracy = new Accuracy(); + expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); + } + { + Aggregations aggs = new Aggregations(Arrays.asList( + createSingleMetricAgg("classification_overall_accuracy", 0.8123), + createSingleMetricAgg("some_other_single_metric_agg", 0.2377) + )); + Accuracy accuracy = new Accuracy(); + expectThrows(NullPointerException.class, () -> accuracy.process(aggs)); + } + } + + public void testProcess_GivenAggOfWrongType() { + { + Aggregations aggs = new Aggregations(Arrays.asList( + createTermsAgg("classification_classes"), + createTermsAgg("classification_overall_accuracy") + )); + Accuracy accuracy = new Accuracy(); + expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); + } + { + Aggregations aggs = new Aggregations(Arrays.asList( + createSingleMetricAgg("classification_classes", 1.0), + createSingleMetricAgg("classification_overall_accuracy", 0.8123) + )); + Accuracy accuracy = new Accuracy(); + expectThrows(ClassCastException.class, () -> accuracy.process(aggs)); + } + } + + private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) { + NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class); + when(agg.getName()).thenReturn(name); + when(agg.value()).thenReturn(value); + return agg; + } + + private static Terms createTermsAgg(String name) { + Terms agg = mock(Terms.class); + when(agg.getName()).thenReturn(name); + return agg; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java index 96cbdf843db..bceee8b399e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -51,10 +51,9 @@ public class ClassificationTests extends AbstractSerializingTestCase metrics = + randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom())); + return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index 974be222a51..c876174d290 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; @@ -21,6 +22,7 @@ import java.util.Arrays; import java.util.List; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -44,69 +46,43 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); - MulticlassConfusionMatrix.Result confusionMatrixResult = - (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); - assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); assertThat( - confusionMatrixResult.getConfusionMatrix(), - equalTo(Arrays.asList( - new ActualClass("ant", - 15, - Arrays.asList( - new PredictedClass("ant", 1L), - new PredictedClass("cat", 4L), - new PredictedClass("dog", 3L), - new PredictedClass("fox", 2L), - new PredictedClass("mouse", 5L)), - 0), - new ActualClass("cat", - 15, - Arrays.asList( - new PredictedClass("ant", 3L), - new PredictedClass("cat", 1L), - new PredictedClass("dog", 5L), - new PredictedClass("fox", 4L), - new PredictedClass("mouse", 2L)), - 0), - new ActualClass("dog", - 15, - Arrays.asList( - new PredictedClass("ant", 4L), - new PredictedClass("cat", 2L), - new PredictedClass("dog", 1L), - new PredictedClass("fox", 5L), - new PredictedClass("mouse", 3L)), - 0), - new ActualClass("fox", - 15, - Arrays.asList( - new PredictedClass("ant", 5L), - new PredictedClass("cat", 3L), - new PredictedClass("dog", 2L), - new PredictedClass("fox", 1L), - new PredictedClass("mouse", 4L)), - 0), - new ActualClass("mouse", - 15, - Arrays.asList( - new PredictedClass("ant", 2L), - new PredictedClass("cat", 5L), - new PredictedClass("dog", 4L), - new PredictedClass("fox", 3L), - new PredictedClass("mouse", 1L)), - 0)))); - assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); + evaluateDataFrameResponse.getMetrics().get(0).getMetricName(), + equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); } - public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() { + public void testEvaluate_MulticlassClassification_Accuracy() { + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + evaluateDataFrame( + ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy()))); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + + Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName())); + assertThat( + accuracyResult.getActualClasses(), + equalTo( + Arrays.asList( + new Accuracy.ActualClass("ant", 15, 1.0 / 15), + new Accuracy.ActualClass("cat", 15, 1.0 / 15), + new Accuracy.ActualClass("dog", 15, 1.0 / 15), + new Accuracy.ActualClass("fox", 15, 1.0 / 15), + new Accuracy.ActualClass("mouse", 15, 1.0 / 15)))); + assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75)); + } + + public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() { EvaluateDataFrameAction.Response evaluateDataFrameResponse = evaluateDataFrame( ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); @@ -168,7 +144,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3)))); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); MulticlassConfusionMatrix.Result confusionMatrixResult = (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index f35346fc785..9d0d645e3d3 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -603,6 +603,35 @@ setup: } } --- +"Test classification accuracy": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { "accuracy": {} } + } + } + } + + - match: + classification.accuracy: + actual_classes: + - actual_class: "cat" + actual_class_doc_count: 3 + accuracy: 0.6666666666666666 # 2 out of 3 + - actual_class: "dog" + actual_class_doc_count: 3 + accuracy: 0.6666666666666666 # 2 out of 3 + - actual_class: "mouse" + actual_class_doc_count: 2 + accuracy: 0.5 # 1 out of 2 + overall_accuracy: 0.625 # 5 out of 8 +--- "Test classification multiclass_confusion_matrix": - do: ml.evaluate_data_frame: