This commit is contained in:
parent
03600e4e12
commit
c7ac2011eb
|
@ -18,6 +18,7 @@
|
|||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
|
@ -51,6 +52,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(MulticlassConfusionMatrixMetric.NAME),
|
||||
|
@ -68,6 +71,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(MulticlassConfusionMatrixMetric.NAME),
|
||||
|
|
|
@ -0,0 +1,211 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* {@link AccuracyMetric} is a metric that answers the question:
|
||||
* "What fraction of examples have been classified correctly by the classifier?"
|
||||
*
|
||||
* equation: accuracy = 1/n * Σ(y == y´)
|
||||
*/
|
||||
public class AccuracyMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "accuracy";
|
||||
|
||||
private static final ObjectParser<AccuracyMetric, Void> PARSER = new ObjectParser<>(NAME, true, AccuracyMetric::new);
|
||||
|
||||
public static AccuracyMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public AccuracyMetric() {}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(NAME);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
|
||||
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of actual classes. */
|
||||
private final List<ActualClass> actualClasses;
|
||||
/** Fraction of documents predicted correctly. */
|
||||
private final double overallAccuracy;
|
||||
|
||||
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
|
||||
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
|
||||
this.overallAccuracy = overallAccuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public List<ActualClass> getActualClasses() {
|
||||
return actualClasses;
|
||||
}
|
||||
|
||||
public double getOverallAccuracy() {
|
||||
return overallAccuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
|
||||
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.actualClasses, that.actualClasses)
|
||||
&& this.overallAccuracy == that.overallAccuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClasses, overallAccuracy);
|
||||
}
|
||||
}
|
||||
|
||||
public static class ActualClass implements ToXContentObject {
|
||||
|
||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
||||
private static final ParseField ACCURACY = new ParseField("accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
||||
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
||||
PARSER.declareDouble(constructorArg(), ACCURACY);
|
||||
}
|
||||
|
||||
/** Name of the actual class. */
|
||||
private final String actualClass;
|
||||
/** Number of documents (examples) belonging to the {code actualClass} class. */
|
||||
private final long actualClassDocCount;
|
||||
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
|
||||
private final double accuracy;
|
||||
|
||||
public ActualClass(
|
||||
String actualClass, long actualClassDocCount, double accuracy) {
|
||||
this.actualClass = Objects.requireNonNull(actualClass);
|
||||
this.actualClassDocCount = actualClassDocCount;
|
||||
this.accuracy = accuracy;
|
||||
}
|
||||
|
||||
public String getActualClass() {
|
||||
return actualClass;
|
||||
}
|
||||
|
||||
public long getActualClassDocCount() {
|
||||
return actualClassDocCount;
|
||||
}
|
||||
|
||||
public double getAccuracy() {
|
||||
return accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
||||
builder.field(ACCURACY.getPreferredName(), accuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ActualClass that = (ActualClass) o;
|
||||
return Objects.equals(this.actualClass, that.actualClass)
|
||||
&& this.actualClassDocCount == that.actualClassDocCount
|
||||
&& this.accuracy == that.accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClass, actualClassDocCount, accuracy);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -125,6 +125,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
|||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||
|
@ -1813,6 +1814,27 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
|
||||
{ // Accuracy
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new AccuracyMetric.ActualClass("cat", 5, 0.6), // 3 out of 5 examples labeled as "cat" were classified correctly
|
||||
new AccuracyMetric.ActualClass("dog", 4, 0.75), // 3 out of 4 examples labeled as "dog" were classified correctly
|
||||
new AccuracyMetric.ActualClass("ant", 1, 0.0)))); // no examples labeled as "ant" were classified correctly
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
|
||||
}
|
||||
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
|
|
|
@ -57,6 +57,7 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction;
|
|||
import org.elasticsearch.client.indexlifecycle.UnfollowAction;
|
||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
|
@ -687,7 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(49, namedXContents.size());
|
||||
assertEquals(51, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -729,21 +730,23 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
AccuracyMetric.NAME,
|
||||
MulticlassConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
AccuracyMetric.NAME,
|
||||
MulticlassConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
|
|
|
@ -141,6 +141,7 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
|||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||
|
@ -3347,20 +3348,27 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
"actual_class", // <2>
|
||||
"predicted_class", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
new MulticlassConfusionMatrixMetric(3)); // <5>
|
||||
new AccuracyMetric(), // <5>
|
||||
new MulticlassConfusionMatrixMetric(3)); // <6>
|
||||
// end::evaluate-data-frame-evaluation-classification
|
||||
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
||||
EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
|
||||
|
||||
// tag::evaluate-data-frame-results-classification
|
||||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
|
||||
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
|
||||
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
|
||||
|
||||
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
|
||||
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
|
||||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3>
|
||||
|
||||
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4>
|
||||
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5>
|
||||
// end::evaluate-data-frame-results-classification
|
||||
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||
assertThat(accuracy, equalTo(0.6));
|
||||
|
||||
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
assertThat(
|
||||
confusionMatrix,
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class AccuracyMetricResultTests extends AbstractXContentTestCase<AccuracyMetric.Result> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AccuracyMetric.Result createTestInstance() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
|
||||
}
|
||||
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(actualClasses, overallAccuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AccuracyMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return AccuracyMetric.Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class AccuracyMetricTests extends AbstractXContentTestCase<AccuracyMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
static AccuracyMetric createRandom() {
|
||||
return new AccuracyMetric();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AccuracyMetric createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AccuracyMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return AccuracyMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@
|
|||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -25,6 +26,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
|
||||
|
@ -34,11 +36,10 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static Classification createRandom() {
|
||||
return new Classification(
|
||||
randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : Arrays.asList(new MulticlassConfusionMatrixMetric()));
|
||||
static Classification createRandom() {
|
||||
List<EvaluationMetric> metrics =
|
||||
randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom()));
|
||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -32,12 +32,16 @@ public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCa
|
|||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric createTestInstance() {
|
||||
static MulticlassConfusionMatrixMetric createRandom() {
|
||||
Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null;
|
||||
return new MulticlassConfusionMatrixMetric(size);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrixMetric.fromXContent(parser);
|
||||
|
|
|
@ -52,7 +52,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
|
|||
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
|
||||
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
|
||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<5> Multiclass confusion matrix of size 3
|
||||
<5> Accuracy
|
||||
<6> Multiclass confusion matrix of size 3
|
||||
|
||||
===== Regression
|
||||
|
||||
|
@ -101,9 +102,11 @@ include-tagged::{doc-tests-file}[{api}-results-softclassification]
|
|||
include-tagged::{doc-tests-file}[{api}-results-classification]
|
||||
--------------------------------------------------
|
||||
|
||||
<1> Fetching multiclass confusion matrix metric by name
|
||||
<2> Fetching the contents of the confusion matrix
|
||||
<3> Fetching the number of classes that were not included in the matrix
|
||||
<1> Fetching accuracy metric by name
|
||||
<2> Fetching the actual accuracy value
|
||||
<3> Fetching multiclass confusion matrix metric by name
|
||||
<4> Fetching the contents of the confusion matrix
|
||||
<5> Fetching the number of classes that were not included in the matrix
|
||||
|
||||
===== Regression
|
||||
|
||||
|
|
|
@ -148,6 +148,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||
|
@ -516,6 +517,8 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
|||
MulticlassConfusionMatrix::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix.Result::new),
|
||||
new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, Accuracy.NAME.getPreferredName(), Accuracy.Result::new),
|
||||
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new),
|
||||
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new),
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
|||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
|
@ -48,6 +49,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
// Classification metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
|
||||
MulticlassConfusionMatrix::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent));
|
||||
|
||||
// Regression metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
|
||||
|
@ -78,6 +80,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
|
||||
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError::new));
|
||||
|
@ -95,6 +98,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
Accuracy.NAME.getPreferredName(),
|
||||
Accuracy.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError.Result::new));
|
||||
|
|
|
@ -0,0 +1,294 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* {@link Accuracy} is a metric that answers the question:
|
||||
* "What fraction of examples have been classified correctly by the classifier?"
|
||||
*
|
||||
* equation: accuracy = 1/n * Σ(y == y´)
|
||||
*/
|
||||
public class Accuracy implements ClassificationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("accuracy");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
private static final String CLASSES_AGG_NAME = "classification_classes";
|
||||
private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
|
||||
private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
||||
|
||||
private static String buildScript(Object...args) {
|
||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
||||
}
|
||||
|
||||
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
|
||||
|
||||
public static Accuracy fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
public Accuracy() {}
|
||||
|
||||
public Accuracy(StreamInput in) throws IOException {}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
Script accuracyScript = new Script(buildScript(actualField, predictedField));
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.terms(CLASSES_AGG_NAME)
|
||||
.field(actualField)
|
||||
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
|
||||
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (result != null) {
|
||||
return;
|
||||
}
|
||||
Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
|
||||
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
|
||||
List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
|
||||
for (Terms.Bucket bucket : classesAgg.getBuckets()) {
|
||||
String actualClass = bucket.getKeyAsString();
|
||||
long actualClassDocCount = bucket.getDocCount();
|
||||
NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
|
||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
|
||||
}
|
||||
result = new Result(actualClasses, overallAccuracyAgg.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(NAME.getPreferredName());
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
|
||||
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of actual classes. */
|
||||
private final List<ActualClass> actualClasses;
|
||||
/** Fraction of documents predicted correctly. */
|
||||
private final double overallAccuracy;
|
||||
|
||||
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
|
||||
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
|
||||
this.overallAccuracy = overallAccuracy;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
|
||||
this.overallAccuracy = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public List<ActualClass> getActualClasses() {
|
||||
return actualClasses;
|
||||
}
|
||||
|
||||
public double getOverallAccuracy() {
|
||||
return overallAccuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeList(actualClasses);
|
||||
out.writeDouble(overallAccuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
|
||||
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.actualClasses, that.actualClasses)
|
||||
&& this.overallAccuracy == that.overallAccuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClasses, overallAccuracy);
|
||||
}
|
||||
}
|
||||
|
||||
public static class ActualClass implements ToXContentObject, Writeable {
|
||||
|
||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
||||
private static final ParseField ACCURACY = new ParseField("accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
||||
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
||||
PARSER.declareDouble(constructorArg(), ACCURACY);
|
||||
}
|
||||
|
||||
/** Name of the actual class. */
|
||||
private final String actualClass;
|
||||
/** Number of documents (examples) belonging to the {code actualClass} class. */
|
||||
private final long actualClassDocCount;
|
||||
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
|
||||
private final double accuracy;
|
||||
|
||||
public ActualClass(
|
||||
String actualClass, long actualClassDocCount, double accuracy) {
|
||||
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
|
||||
this.actualClassDocCount = actualClassDocCount;
|
||||
this.accuracy = accuracy;
|
||||
}
|
||||
|
||||
public ActualClass(StreamInput in) throws IOException {
|
||||
this.actualClass = in.readString();
|
||||
this.actualClassDocCount = in.readVLong();
|
||||
this.accuracy = in.readDouble();
|
||||
}
|
||||
|
||||
public String getActualClass() {
|
||||
return actualClass;
|
||||
}
|
||||
|
||||
public double getAccuracy() {
|
||||
return accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualClass);
|
||||
out.writeVLong(actualClassDocCount);
|
||||
out.writeDouble(accuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
||||
builder.field(ACCURACY.getPreferredName(), accuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ActualClass that = (ActualClass) o;
|
||||
return Objects.equals(this.actualClass, that.actualClass)
|
||||
&& this.actualClassDocCount == that.actualClassDocCount
|
||||
&& this.accuracy == that.accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClass, actualClassDocCount, accuracy);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accuracy.Result> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Accuracy.Result createTestInstance() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
|
||||
}
|
||||
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(actualClasses, overallAccuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Accuracy.Result> instanceReader() {
|
||||
return Accuracy.Result::new;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||
|
||||
@Override
|
||||
protected Accuracy doParseInstance(XContentParser parser) throws IOException {
|
||||
return Accuracy.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Accuracy createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Accuracy> instanceReader() {
|
||||
return Accuracy::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public static Accuracy createRandom() {
|
||||
return new Accuracy();
|
||||
}
|
||||
|
||||
public void testProcess() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createTermsAgg("classification_classes"),
|
||||
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
Accuracy accuracy = new Accuracy();
|
||||
accuracy.process(aggs);
|
||||
|
||||
assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123)));
|
||||
}
|
||||
|
||||
public void testProcess_GivenMissingAgg() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createTermsAgg("classification_classes"),
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("classification_overall_accuracy", 0.8123),
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenAggOfWrongType() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createTermsAgg("classification_classes"),
|
||||
createTermsAgg("classification_overall_accuracy")
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("classification_classes", 1.0),
|
||||
createSingleMetricAgg("classification_overall_accuracy", 0.8123)
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
}
|
||||
|
||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
||||
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
|
||||
when(agg.getName()).thenReturn(name);
|
||||
when(agg.value()).thenReturn(value);
|
||||
return agg;
|
||||
}
|
||||
|
||||
private static Terms createTermsAgg(String name) {
|
||||
Terms agg = mock(Terms.class);
|
||||
when(agg.getName()).thenReturn(name);
|
||||
return agg;
|
||||
}
|
||||
}
|
|
@ -51,10 +51,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
|||
}
|
||||
|
||||
public static Classification createRandom() {
|
||||
return new Classification(
|
||||
randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom()));
|
||||
List<ClassificationMetric> metrics =
|
||||
randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom()));
|
||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
|
|||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
|
@ -21,6 +22,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
|
@ -44,69 +46,43 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
assertThat(
|
||||
confusionMatrixResult.getConfusionMatrix(),
|
||||
equalTo(Arrays.asList(
|
||||
new ActualClass("ant",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 1L),
|
||||
new PredictedClass("cat", 4L),
|
||||
new PredictedClass("dog", 3L),
|
||||
new PredictedClass("fox", 2L),
|
||||
new PredictedClass("mouse", 5L)),
|
||||
0),
|
||||
new ActualClass("cat",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 3L),
|
||||
new PredictedClass("cat", 1L),
|
||||
new PredictedClass("dog", 5L),
|
||||
new PredictedClass("fox", 4L),
|
||||
new PredictedClass("mouse", 2L)),
|
||||
0),
|
||||
new ActualClass("dog",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 4L),
|
||||
new PredictedClass("cat", 2L),
|
||||
new PredictedClass("dog", 1L),
|
||||
new PredictedClass("fox", 5L),
|
||||
new PredictedClass("mouse", 3L)),
|
||||
0),
|
||||
new ActualClass("fox",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 5L),
|
||||
new PredictedClass("cat", 3L),
|
||||
new PredictedClass("dog", 2L),
|
||||
new PredictedClass("fox", 1L),
|
||||
new PredictedClass("mouse", 4L)),
|
||||
0),
|
||||
new ActualClass("mouse",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 2L),
|
||||
new PredictedClass("cat", 5L),
|
||||
new PredictedClass("dog", 4L),
|
||||
new PredictedClass("fox", 3L),
|
||||
new PredictedClass("mouse", 1L)),
|
||||
0))));
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
evaluateDataFrameResponse.getMetrics().get(0).getMetricName(),
|
||||
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
|
||||
public void testEvaluate_MulticlassClassification_Accuracy() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new Accuracy())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.ActualClass("ant", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("cat", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("dog", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("fox", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("mouse", 15, 1.0 / 15))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
|
@ -168,7 +144,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
|
|
|
@ -603,6 +603,35 @@ setup:
|
|||
}
|
||||
}
|
||||
---
|
||||
"Test classification accuracy":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "classification_field_pred.keyword",
|
||||
"metrics": { "accuracy": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match:
|
||||
classification.accuracy:
|
||||
actual_classes:
|
||||
- actual_class: "cat"
|
||||
actual_class_doc_count: 3
|
||||
accuracy: 0.6666666666666666 # 2 out of 3
|
||||
- actual_class: "dog"
|
||||
actual_class_doc_count: 3
|
||||
accuracy: 0.6666666666666666 # 2 out of 3
|
||||
- actual_class: "mouse"
|
||||
actual_class_doc_count: 2
|
||||
accuracy: 0.5 # 1 out of 2
|
||||
overall_accuracy: 0.625 # 5 out of 8
|
||||
---
|
||||
"Test classification multiclass_confusion_matrix":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
|
|
Loading…
Reference in New Issue