[7.x] Implement accuracy metric for multiclass classification (#47772) (#49430)

This commit is contained in:
Przemysław Witek 2019-11-21 15:01:18 +01:00 committed by GitHub
parent 03600e4e12
commit c7ac2011eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 914 additions and 78 deletions

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.dataframe.evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
@ -51,6 +52,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(MulticlassConfusionMatrixMetric.NAME),
@ -68,6 +71,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(MulticlassConfusionMatrixMetric.NAME),

View File

@ -0,0 +1,211 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* {@link AccuracyMetric} is a metric that answers the question:
* "What fraction of examples have been classified correctly by the classifier?"
*
* equation: accuracy = 1/n * Σ(y == y´)
*/
public class AccuracyMetric implements EvaluationMetric {
public static final String NAME = "accuracy";
private static final ObjectParser<AccuracyMetric, Void> PARSER = new ObjectParser<>(NAME, true, AccuracyMetric::new);
public static AccuracyMetric fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public AccuracyMetric() {}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
}
@Override
public int hashCode() {
return Objects.hashCode(NAME);
}
public static class Result implements EvaluationMetric.Result {
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/** List of actual classes. */
private final List<ActualClass> actualClasses;
/** Fraction of documents predicted correctly. */
private final double overallAccuracy;
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
this.overallAccuracy = overallAccuracy;
}
@Override
public String getMetricName() {
return NAME;
}
public List<ActualClass> getActualClasses() {
return actualClasses;
}
public double getOverallAccuracy() {
return overallAccuracy;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses)
&& this.overallAccuracy == that.overallAccuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy);
}
}
public static class ActualClass implements ToXContentObject {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField ACCURACY = new ParseField("accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareDouble(constructorArg(), ACCURACY);
}
/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {code actualClass} class. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
private final double accuracy;
public ActualClass(
String actualClass, long actualClassDocCount, double accuracy) {
this.actualClass = Objects.requireNonNull(actualClass);
this.actualClassDocCount = actualClassDocCount;
this.accuracy = accuracy;
}
public String getActualClass() {
return actualClass;
}
public long getActualClassDocCount() {
return actualClassDocCount;
}
public double getAccuracy() {
return accuracy;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
&& this.accuracy == that.accuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy);
}
}
}

View File

@ -125,6 +125,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
@ -1813,6 +1814,27 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
{ // Accuracy
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat(
accuracyResult.getActualClasses(),
equalTo(
Arrays.asList(
new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly
new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly
new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
}
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(

View File

@ -57,6 +57,7 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction;
import org.elasticsearch.client.indexlifecycle.UnfollowAction;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
@ -687,7 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(49, namedXContents.size());
assertEquals(51, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -729,21 +730,23 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertThat(names,
hasItems(AucRocMetric.NAME,
PrecisionMetric.NAME,
RecallMetric.NAME,
ConfusionMatrixMetric.NAME,
AccuracyMetric.NAME,
MulticlassConfusionMatrixMetric.NAME,
MeanSquaredErrorMetric.NAME,
RSquaredMetric.NAME));
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names,
hasItems(AucRocMetric.NAME,
PrecisionMetric.NAME,
RecallMetric.NAME,
ConfusionMatrixMetric.NAME,
AccuracyMetric.NAME,
MulticlassConfusionMatrixMetric.NAME,
MeanSquaredErrorMetric.NAME,
RSquaredMetric.NAME));

View File

@ -141,6 +141,7 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
@ -3347,20 +3348,27 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
"actual_class", // <2>
"predicted_class", // <3>
// Evaluation metrics // <4>
new MulticlassConfusionMatrixMetric(3)); // <5>
new AccuracyMetric(), // <5>
new MulticlassConfusionMatrixMetric(3)); // <6>
// end::evaluate-data-frame-evaluation-classification
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
// tag::evaluate-data-frame-results-classification
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3>
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4>
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5>
// end::evaluate-data-frame-results-classification
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat(accuracy, equalTo(0.6));
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
assertThat(
confusionMatrix,

View File

@ -0,0 +1,63 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class AccuracyMetricResultTests extends AbstractXContentTestCase<AccuracyMetric.Result> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected AccuracyMetric.Result createTestInstance() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
}
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy);
}
@Override
protected AccuracyMetric.Result doParseInstance(XContentParser parser) throws IOException {
return AccuracyMetric.Result.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class AccuracyMetricTests extends AbstractXContentTestCase<AccuracyMetric> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
static AccuracyMetric createRandom() {
return new AccuracyMetric();
}
@Override
protected AccuracyMetric createTestInstance() {
return createRandom();
}
@Override
protected AccuracyMetric doParseInstance(XContentParser parser) throws IOException {
return AccuracyMetric.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
@ -25,6 +26,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
@ -34,11 +36,10 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
public static Classification createRandom() {
return new Classification(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomBoolean() ? null : Arrays.asList(new MulticlassConfusionMatrixMetric()));
static Classification createRandom() {
List<EvaluationMetric> metrics =
randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}
@Override

View File

@ -32,12 +32,16 @@ public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCa
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected MulticlassConfusionMatrixMetric createTestInstance() {
static MulticlassConfusionMatrixMetric createRandom() {
Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null;
return new MulticlassConfusionMatrixMetric(size);
}
@Override
protected MulticlassConfusionMatrixMetric createTestInstance() {
return createRandom();
}
@Override
protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrixMetric.fromXContent(parser);

View File

@ -52,7 +52,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
<5> Multiclass confusion matrix of size 3
<5> Accuracy
<6> Multiclass confusion matrix of size 3
===== Regression
@ -101,9 +102,11 @@ include-tagged::{doc-tests-file}[{api}-results-softclassification]
include-tagged::{doc-tests-file}[{api}-results-classification]
--------------------------------------------------
<1> Fetching multiclass confusion matrix metric by name
<2> Fetching the contents of the confusion matrix
<3> Fetching the number of classes that were not included in the matrix
<1> Fetching accuracy metric by name
<2> Fetching the actual accuracy value
<3> Fetching multiclass confusion matrix metric by name
<4> Fetching the contents of the confusion matrix
<5> Fetching the number of classes that were not included in the matrix
===== Regression

View File

@ -148,6 +148,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@ -516,6 +517,8 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
MulticlassConfusionMatrix::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix.Result::new),
new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, Accuracy.NAME.getPreferredName(), Accuracy.Result::new),
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new),
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new),

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
@ -48,6 +49,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Classification metrics
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
MulticlassConfusionMatrix::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent));
// Regression metrics
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
@ -78,6 +80,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError::new));
@ -95,6 +98,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
Accuracy.NAME.getPreferredName(),
Accuracy.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError.Result::new));

View File

@ -0,0 +1,294 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* {@link Accuracy} is a metric that answers the question:
* "What fraction of examples have been classified correctly by the classifier?"
*
* equation: accuracy = 1/n * Σ(y == y´)
*/
public class Accuracy implements ClassificationMetric {
public static final ParseField NAME = new ParseField("accuracy");
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String CLASSES_AGG_NAME = "classification_classes";
private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
private static String buildScript(Object...args) {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
}
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
public static Accuracy fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private EvaluationMetricResult result;
public Accuracy() {}
public Accuracy(StreamInput in) throws IOException {}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
if (result != null) {
return Collections.emptyList();
}
Script accuracyScript = new Script(buildScript(actualField, predictedField));
return Arrays.asList(
AggregationBuilders.terms(CLASSES_AGG_NAME)
.field(actualField)
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript));
}
@Override
public void process(Aggregations aggs) {
if (result != null) {
return;
}
Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
for (Terms.Bucket bucket : classesAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString();
long actualClassDocCount = bucket.getDocCount();
NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
}
result = new Result(actualClasses, overallAccuracyAgg.value());
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
}
@Override
public int hashCode() {
return Objects.hashCode(NAME.getPreferredName());
}
public static class Result implements EvaluationMetricResult {
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/** List of actual classes. */
private final List<ActualClass> actualClasses;
/** Fraction of documents predicted correctly. */
private final double overallAccuracy;
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
this.overallAccuracy = overallAccuracy;
}
public Result(StreamInput in) throws IOException {
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
this.overallAccuracy = in.readDouble();
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
public List<ActualClass> getActualClasses() {
return actualClasses;
}
public double getOverallAccuracy() {
return overallAccuracy;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(actualClasses);
out.writeDouble(overallAccuracy);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.actualClasses, that.actualClasses)
&& this.overallAccuracy == that.overallAccuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClasses, overallAccuracy);
}
}
public static class ActualClass implements ToXContentObject, Writeable {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField ACCURACY = new ParseField("accuracy");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareDouble(constructorArg(), ACCURACY);
}
/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {code actualClass} class. */
private final long actualClassDocCount;
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
private final double accuracy;
public ActualClass(
String actualClass, long actualClassDocCount, double accuracy) {
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
this.actualClassDocCount = actualClassDocCount;
this.accuracy = accuracy;
}
public ActualClass(StreamInput in) throws IOException {
this.actualClass = in.readString();
this.actualClassDocCount = in.readVLong();
this.accuracy = in.readDouble();
}
public String getActualClass() {
return actualClass;
}
public double getAccuracy() {
return accuracy;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualClass);
out.writeVLong(actualClassDocCount);
out.writeDouble(accuracy);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(ACCURACY.getPreferredName(), accuracy);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
&& this.accuracy == that.accuracy;
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, accuracy);
}
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accuracy.Result> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
}
@Override
protected Accuracy.Result createTestInstance() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double accuracy = randomDoubleBetween(0.0, 1.0, true);
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
}
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
return new Result(actualClasses, overallAccuracy);
}
@Override
protected Writeable.Reader<Accuracy.Result> instanceReader() {
return Accuracy.Result::new;
}
}

View File

@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
@Override
protected Accuracy doParseInstance(XContentParser parser) throws IOException {
return Accuracy.fromXContent(parser);
}
@Override
protected Accuracy createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Accuracy> instanceReader() {
return Accuracy::new;
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
public static Accuracy createRandom() {
return new Accuracy();
}
public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"),
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
Accuracy accuracy = new Accuracy();
accuracy.process(aggs);
assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123)));
}
public void testProcess_GivenMissingAgg() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
Accuracy accuracy = new Accuracy();
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
}
}
public void testProcess_GivenAggOfWrongType() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
createTermsAgg("classification_classes"),
createTermsAgg("classification_overall_accuracy")
));
Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("classification_classes", 1.0),
createSingleMetricAgg("classification_overall_accuracy", 0.8123)
));
Accuracy accuracy = new Accuracy();
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
}
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
private static Terms createTermsAgg(String name) {
Terms agg = mock(Terms.class);
when(agg.getName()).thenReturn(name);
return agg;
}
}

View File

@ -51,10 +51,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
}
public static Classification createRandom() {
return new Classification(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomBoolean() ? null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom()));
List<ClassificationMetric> metrics =
randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}
@Override

View File

@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
@ -21,6 +22,7 @@ import java.util.Arrays;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@ -44,69 +46,43 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
assertThat(
confusionMatrixResult.getConfusionMatrix(),
equalTo(Arrays.asList(
new ActualClass("ant",
15,
Arrays.asList(
new PredictedClass("ant", 1L),
new PredictedClass("cat", 4L),
new PredictedClass("dog", 3L),
new PredictedClass("fox", 2L),
new PredictedClass("mouse", 5L)),
0),
new ActualClass("cat",
15,
Arrays.asList(
new PredictedClass("ant", 3L),
new PredictedClass("cat", 1L),
new PredictedClass("dog", 5L),
new PredictedClass("fox", 4L),
new PredictedClass("mouse", 2L)),
0),
new ActualClass("dog",
15,
Arrays.asList(
new PredictedClass("ant", 4L),
new PredictedClass("cat", 2L),
new PredictedClass("dog", 1L),
new PredictedClass("fox", 5L),
new PredictedClass("mouse", 3L)),
0),
new ActualClass("fox",
15,
Arrays.asList(
new PredictedClass("ant", 5L),
new PredictedClass("cat", 3L),
new PredictedClass("dog", 2L),
new PredictedClass("fox", 1L),
new PredictedClass("mouse", 4L)),
0),
new ActualClass("mouse",
15,
Arrays.asList(
new PredictedClass("ant", 2L),
new PredictedClass("cat", 5L),
new PredictedClass("dog", 4L),
new PredictedClass("fox", 3L),
new PredictedClass("mouse", 1L)),
0))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
evaluateDataFrameResponse.getMetrics().get(0).getMetricName(),
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
}
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
public void testEvaluate_MulticlassClassification_Accuracy() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
assertThat(
accuracyResult.getActualClasses(),
equalTo(
Arrays.asList(
new Accuracy.ActualClass("ant", 15, 1.0 / 15),
new Accuracy.ActualClass("cat", 15, 1.0 / 15),
new Accuracy.ActualClass("dog", 15, 1.0 / 15),
new Accuracy.ActualClass("fox", 15, 1.0 / 15),
new Accuracy.ActualClass("mouse", 15, 1.0 / 15))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
}
public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
@ -168,7 +144,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));

View File

@ -603,6 +603,35 @@ setup:
}
}
---
"Test classification accuracy":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { "accuracy": {} }
}
}
}
- match:
classification.accuracy:
actual_classes:
- actual_class: "cat"
actual_class_doc_count: 3
accuracy: 0.6666666666666666 # 2 out of 3
- actual_class: "dog"
actual_class_doc_count: 3
accuracy: 0.6666666666666666 # 2 out of 3
- actual_class: "mouse"
actual_class_doc_count: 2
accuracy: 0.5 # 1 out of 2
overall_accuracy: 0.625 # 5 out of 8
---
"Test classification multiclass_confusion_matrix":
- do:
ml.evaluate_data_frame: