parent
14d95aae46
commit
3e3a93002f
|
@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
|||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
|
@ -35,10 +36,25 @@ import java.util.Objects;
|
|||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* {@link AccuracyMetric} is a metric that answers the question:
|
||||
* "What fraction of examples have been classified correctly by the classifier?"
|
||||
* {@link AccuracyMetric} is a metric that answers the following two questions:
|
||||
*
|
||||
* equation: accuracy = 1/n * Σ(y == y´)
|
||||
* 1. What is the fraction of documents for which predicted class equals the actual class?
|
||||
*
|
||||
* equation: overall_accuracy = 1/n * Σ(y == y')
|
||||
* where: n = total number of documents
|
||||
* y = document's actual class
|
||||
* y' = document's predicted class
|
||||
*
|
||||
* 2. For any given class X, what is the fraction of documents for which either
|
||||
* a) both actual and predicted class are equal to X (true positives)
|
||||
* or
|
||||
* b) both actual and predicted class are not equal to X (true negatives)
|
||||
*
|
||||
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
|
||||
* where: X = class being examined
|
||||
* n = total number of documents
|
||||
* TP(X) = number of true positives wrt X
|
||||
* TN(X) = number of true negatives wrt X
|
||||
*/
|
||||
public class AccuracyMetric implements EvaluationMetric {
|
||||
|
||||
|
@ -78,15 +94,15 @@ public class AccuracyMetric implements EvaluationMetric {
|
|||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
|
||||
private static final ParseField CLASSES = new ParseField("classes");
|
||||
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
|
||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
|
||||
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
||||
}
|
||||
|
||||
|
@ -94,13 +110,13 @@ public class AccuracyMetric implements EvaluationMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of actual classes. */
|
||||
private final List<ActualClass> actualClasses;
|
||||
/** Fraction of documents predicted correctly. */
|
||||
/** List of per-class results. */
|
||||
private final List<PerClassResult> classes;
|
||||
/** Fraction of documents for which predicted class equals the actual class. */
|
||||
private final double overallAccuracy;
|
||||
|
||||
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
|
||||
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
|
||||
public Result(List<PerClassResult> classes, double overallAccuracy) {
|
||||
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
|
||||
this.overallAccuracy = overallAccuracy;
|
||||
}
|
||||
|
||||
|
@ -109,8 +125,8 @@ public class AccuracyMetric implements EvaluationMetric {
|
|||
return NAME;
|
||||
}
|
||||
|
||||
public List<ActualClass> getActualClasses() {
|
||||
return actualClasses;
|
||||
public List<PerClassResult> getClasses() {
|
||||
return classes;
|
||||
}
|
||||
|
||||
public double getOverallAccuracy() {
|
||||
|
@ -120,7 +136,7 @@ public class AccuracyMetric implements EvaluationMetric {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
|
||||
builder.field(CLASSES.getPreferredName(), classes);
|
||||
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -131,52 +147,42 @@ public class AccuracyMetric implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.actualClasses, that.actualClasses)
|
||||
return Objects.equals(this.classes, that.classes)
|
||||
&& this.overallAccuracy == that.overallAccuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClasses, overallAccuracy);
|
||||
return Objects.hash(classes, overallAccuracy);
|
||||
}
|
||||
}
|
||||
|
||||
public static class ActualClass implements ToXContentObject {
|
||||
public static class PerClassResult implements ToXContentObject {
|
||||
|
||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
||||
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
private static final ParseField ACCURACY = new ParseField("accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
|
||||
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
||||
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
||||
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||
PARSER.declareDouble(constructorArg(), ACCURACY);
|
||||
}
|
||||
|
||||
/** Name of the actual class. */
|
||||
private final String actualClass;
|
||||
/** Number of documents (examples) belonging to the {@code actualClass} class. */
|
||||
private final long actualClassDocCount;
|
||||
/** Fraction of documents belonging to the {@code actualClass} class predicted correctly. */
|
||||
/** Name of the class. */
|
||||
private final String className;
|
||||
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
|
||||
private final double accuracy;
|
||||
|
||||
public ActualClass(
|
||||
String actualClass, long actualClassDocCount, double accuracy) {
|
||||
this.actualClass = Objects.requireNonNull(actualClass);
|
||||
this.actualClassDocCount = actualClassDocCount;
|
||||
public PerClassResult(String className, double accuracy) {
|
||||
this.className = Objects.requireNonNull(className);
|
||||
this.accuracy = accuracy;
|
||||
}
|
||||
|
||||
public String getActualClass() {
|
||||
return actualClass;
|
||||
}
|
||||
|
||||
public long getActualClassDocCount() {
|
||||
return actualClassDocCount;
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public double getAccuracy() {
|
||||
|
@ -186,8 +192,7 @@ public class AccuracyMetric implements EvaluationMetric {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
builder.field(ACCURACY.getPreferredName(), accuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -197,15 +202,19 @@ public class AccuracyMetric implements EvaluationMetric {
|
|||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ActualClass that = (ActualClass) o;
|
||||
return Objects.equals(this.actualClass, that.actualClass)
|
||||
&& this.actualClassDocCount == that.actualClassDocCount
|
||||
PerClassResult that = (PerClassResult) o;
|
||||
return Objects.equals(this.className, that.className)
|
||||
&& this.accuracy == that.accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClass, actualClassDocCount, accuracy);
|
||||
return Objects.hash(className, accuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1849,15 +1849,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
accuracyResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||
new AccuracyMetric.ActualClass("cat", 5, 0.6),
|
||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||
new AccuracyMetric.ActualClass("dog", 4, 0.75),
|
||||
// no examples labeled as "ant" were classified correctly
|
||||
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
|
||||
// 9 out of 10 examples were classified correctly
|
||||
new AccuracyMetric.PerClassResult("ant", 0.9),
|
||||
// 6 out of 10 examples were classified correctly
|
||||
new AccuracyMetric.PerClassResult("cat", 0.6),
|
||||
// 8 out of 10 examples were classified correctly
|
||||
new AccuracyMetric.PerClassResult("dog", 0.8))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
|
||||
}
|
||||
{ // Precision
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
@ -41,13 +41,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result>
|
|||
public static Result randomResult() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
|
||||
classes.add(new PerClassResult(classNames.get(i), accuracy));
|
||||
}
|
||||
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(actualClasses, overallAccuracy);
|
||||
return new Result(classes, overallAccuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
|||
* Gets the evaluation result for this metric.
|
||||
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
|
||||
*/
|
||||
Optional<EvaluationMetricResult> getResult();
|
||||
Optional<? extends EvaluationMetricResult> getResult();
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
|||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
@ -29,7 +29,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
|
@ -40,22 +39,36 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
|
|||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* {@link Accuracy} is a metric that answers the question:
|
||||
* "What fraction of examples have been classified correctly by the classifier?"
|
||||
* {@link Accuracy} is a metric that answers the following two questions:
|
||||
*
|
||||
* equation: accuracy = 1/n * Σ(y == y´)
|
||||
* 1. What is the fraction of documents for which predicted class equals the actual class?
|
||||
*
|
||||
* equation: overall_accuracy = 1/n * Σ(y == y')
|
||||
* where: n = total number of documents
|
||||
* y = document's actual class
|
||||
* y' = document's predicted class
|
||||
*
|
||||
* 2. For any given class X, what is the fraction of documents for which either
|
||||
* a) both actual and predicted class are equal to X (true positives)
|
||||
* or
|
||||
* b) both actual and predicted class are not equal to X (true negatives)
|
||||
*
|
||||
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
|
||||
* where: X = class being examined
|
||||
* n = total number of documents
|
||||
* TP(X) = number of true positives wrt X
|
||||
* TN(X) = number of true negatives wrt X
|
||||
*/
|
||||
public class Accuracy implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("accuracy");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
private static final String CLASSES_AGG_NAME = "classification_classes";
|
||||
private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
|
||||
private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
||||
static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
||||
|
||||
private static String buildScript(Object...args) {
|
||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
|
||||
private static Script buildScript(Object...args) {
|
||||
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
||||
}
|
||||
|
||||
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
|
||||
|
@ -64,11 +77,20 @@ public class Accuracy implements EvaluationMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private EvaluationMetricResult result;
|
||||
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||
|
||||
public Accuracy() {}
|
||||
private final MulticlassConfusionMatrix matrix;
|
||||
private final SetOnce<String> actualField = new SetOnce<>();
|
||||
private final SetOnce<Double> overallAccuracy = new SetOnce<>();
|
||||
private final SetOnce<Result> result = new SetOnce<>();
|
||||
|
||||
public Accuracy(StreamInput in) throws IOException {}
|
||||
public Accuracy() {
|
||||
this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
|
||||
}
|
||||
|
||||
public Accuracy(StreamInput in) throws IOException {
|
||||
this.matrix = new MulticlassConfusionMatrix(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
|
@ -82,43 +104,79 @@ public class Accuracy implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField.trySet(actualField);
|
||||
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
|
||||
if (overallAccuracy.get() == null) {
|
||||
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
|
||||
}
|
||||
Script accuracyScript = new Script(buildScript(actualField, predictedField));
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(CLASSES_AGG_NAME)
|
||||
.field(actualField)
|
||||
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
|
||||
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
|
||||
Collections.emptyList());
|
||||
if (result.get() == null) {
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
|
||||
aggs.addAll(matrixAggs.v1());
|
||||
pipelineAggs.addAll(matrixAggs.v2());
|
||||
}
|
||||
return Tuple.tuple(aggs, pipelineAggs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (result != null) {
|
||||
return;
|
||||
if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
|
||||
overallAccuracy.set(overallAccuracyAgg.value());
|
||||
}
|
||||
Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
|
||||
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
|
||||
List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
|
||||
for (Terms.Bucket bucket : classesAgg.getBuckets()) {
|
||||
String actualClass = bucket.getKeyAsString();
|
||||
long actualClassDocCount = bucket.getDocCount();
|
||||
NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
|
||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
|
||||
matrix.process(aggs);
|
||||
if (result.get() == null && matrix.getResult().isPresent()) {
|
||||
if (matrix.getResult().get().getOtherActualClassCount() > 0) {
|
||||
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
||||
// We cannot calculate per-class accuracy accurately, so we fail.
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get());
|
||||
}
|
||||
result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get()));
|
||||
}
|
||||
result = new Result(actualClasses, overallAccuracyAgg.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
public Optional<Result> getResult() {
|
||||
return Optional.ofNullable(result.get());
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the per-class accuracy results based on multiclass confusion matrix's result.
|
||||
* Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension.
|
||||
* This method is visible for testing only.
|
||||
*/
|
||||
static List<PerClassResult> computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) {
|
||||
assert matrixResult.getOtherActualClassCount() == 0;
|
||||
// Number of actual classes taken into account
|
||||
int n = matrixResult.getConfusionMatrix().size();
|
||||
// Total number of documents taken into account
|
||||
long totalDocCount =
|
||||
matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum();
|
||||
List<PerClassResult> classes = new ArrayList<>(n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
String className = matrixResult.getConfusionMatrix().get(i).getActualClass();
|
||||
// Start with the assumption that all the docs were predicted correctly.
|
||||
long correctDocCount = totalDocCount;
|
||||
for (int j = 0; j < n; ++j) {
|
||||
if (i != j) {
|
||||
// Subtract errors (false negatives)
|
||||
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount();
|
||||
// Subtract errors (false positives)
|
||||
correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount();
|
||||
}
|
||||
}
|
||||
// Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix
|
||||
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount();
|
||||
classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount));
|
||||
}
|
||||
return classes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
matrix.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -132,25 +190,26 @@ public class Accuracy implements EvaluationMetric {
|
|||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
Accuracy that = (Accuracy) o;
|
||||
return Objects.equals(this.matrix, that.matrix);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(NAME.getPreferredName());
|
||||
return Objects.hash(matrix);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
|
||||
private static final ParseField CLASSES = new ParseField("classes");
|
||||
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
|
||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
|
||||
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
||||
}
|
||||
|
||||
|
@ -158,18 +217,18 @@ public class Accuracy implements EvaluationMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of actual classes. */
|
||||
private final List<ActualClass> actualClasses;
|
||||
/** Fraction of documents predicted correctly. */
|
||||
/** List of per-class results. */
|
||||
private final List<PerClassResult> classes;
|
||||
/** Fraction of documents for which predicted class equals the actual class. */
|
||||
private final double overallAccuracy;
|
||||
|
||||
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
|
||||
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
|
||||
public Result(List<PerClassResult> classes, double overallAccuracy) {
|
||||
this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
|
||||
this.overallAccuracy = overallAccuracy;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
|
||||
this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
|
||||
this.overallAccuracy = in.readDouble();
|
||||
}
|
||||
|
||||
|
@ -183,8 +242,8 @@ public class Accuracy implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public List<ActualClass> getActualClasses() {
|
||||
return actualClasses;
|
||||
public List<PerClassResult> getClasses() {
|
||||
return classes;
|
||||
}
|
||||
|
||||
public double getOverallAccuracy() {
|
||||
|
@ -193,14 +252,14 @@ public class Accuracy implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeList(actualClasses);
|
||||
out.writeList(classes);
|
||||
out.writeDouble(overallAccuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
|
||||
builder.field(CLASSES.getPreferredName(), classes);
|
||||
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -211,54 +270,47 @@ public class Accuracy implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.actualClasses, that.actualClasses)
|
||||
return Objects.equals(this.classes, that.classes)
|
||||
&& this.overallAccuracy == that.overallAccuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClasses, overallAccuracy);
|
||||
return Objects.hash(classes, overallAccuracy);
|
||||
}
|
||||
}
|
||||
|
||||
public static class ActualClass implements ToXContentObject, Writeable {
|
||||
public static class PerClassResult implements ToXContentObject, Writeable {
|
||||
|
||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
||||
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
private static final ParseField ACCURACY = new ParseField("accuracy");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
|
||||
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
||||
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
||||
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||
PARSER.declareDouble(constructorArg(), ACCURACY);
|
||||
}
|
||||
|
||||
/** Name of the actual class. */
|
||||
private final String actualClass;
|
||||
/** Number of documents (examples) belonging to the {@code actualClass} class. */
|
||||
private final long actualClassDocCount;
|
||||
/** Fraction of documents belonging to the {@code actualClass} class predicted correctly. */
|
||||
/** Name of the class. */
|
||||
private final String className;
|
||||
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
|
||||
private final double accuracy;
|
||||
|
||||
public ActualClass(
|
||||
String actualClass, long actualClassDocCount, double accuracy) {
|
||||
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
|
||||
this.actualClassDocCount = actualClassDocCount;
|
||||
public PerClassResult(String className, double accuracy) {
|
||||
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
|
||||
this.accuracy = accuracy;
|
||||
}
|
||||
|
||||
public ActualClass(StreamInput in) throws IOException {
|
||||
this.actualClass = in.readString();
|
||||
this.actualClassDocCount = in.readVLong();
|
||||
public PerClassResult(StreamInput in) throws IOException {
|
||||
this.className = in.readString();
|
||||
this.accuracy = in.readDouble();
|
||||
}
|
||||
|
||||
public String getActualClass() {
|
||||
return actualClass;
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public double getAccuracy() {
|
||||
|
@ -267,16 +319,14 @@ public class Accuracy implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualClass);
|
||||
out.writeVLong(actualClassDocCount);
|
||||
out.writeString(className);
|
||||
out.writeDouble(accuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
builder.field(ACCURACY.getPreferredName(), accuracy);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -286,15 +336,14 @@ public class Accuracy implements EvaluationMetric {
|
|||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ActualClass that = (ActualClass) o;
|
||||
return Objects.equals(this.actualClass, that.actualClass)
|
||||
&& this.actualClassDocCount == that.actualClassDocCount
|
||||
PerClassResult that = (PerClassResult) o;
|
||||
return Objects.equals(this.className, that.className)
|
||||
&& this.accuracy == that.accuracy;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClass, actualClassDocCount, accuracy);
|
||||
return Objects.hash(className, accuracy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
|
@ -53,13 +55,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
|
||||
|
||||
public static final ParseField SIZE = new ParseField("size");
|
||||
public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix");
|
||||
|
||||
private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
|
||||
|
||||
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
|
||||
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
|
||||
new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1]));
|
||||
parser.declareInt(optionalConstructorArg(), SIZE);
|
||||
parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX);
|
||||
return parser;
|
||||
}
|
||||
|
||||
|
@ -67,31 +72,39 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
||||
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
||||
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
||||
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
|
||||
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
||||
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
||||
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
||||
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
|
||||
private static final String OTHER_BUCKET_KEY = "_other_";
|
||||
private static final String DEFAULT_AGG_NAME_PREFIX = "";
|
||||
private static final int DEFAULT_SIZE = 10;
|
||||
private static final int MAX_SIZE = 1000;
|
||||
|
||||
private final int size;
|
||||
private List<String> topActualClassNames;
|
||||
private Result result;
|
||||
private final String aggNamePrefix;
|
||||
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
|
||||
private final SetOnce<Result> result = new SetOnce<>();
|
||||
|
||||
public MulticlassConfusionMatrix() {
|
||||
this((Integer) null);
|
||||
this(null, null);
|
||||
}
|
||||
|
||||
public MulticlassConfusionMatrix(@Nullable Integer size) {
|
||||
public MulticlassConfusionMatrix(@Nullable Integer size, @Nullable String aggNamePrefix) {
|
||||
if (size != null && (size <= 0 || size > MAX_SIZE)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
|
||||
}
|
||||
this.size = size != null ? size : DEFAULT_SIZE;
|
||||
this.aggNamePrefix = aggNamePrefix != null ? aggNamePrefix : DEFAULT_AGG_NAME_PREFIX;
|
||||
}
|
||||
|
||||
public MulticlassConfusionMatrix(StreamInput in) throws IOException {
|
||||
this.size = in.readVInt();
|
||||
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
this.aggNamePrefix = in.readString();
|
||||
} else {
|
||||
this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -110,30 +123,30 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
if (topActualClassNames == null) { // This is step 1
|
||||
if (topActualClassNames.get() == null) { // This is step 1
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
|
||||
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(size)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
if (result == null) { // This is step 2
|
||||
if (result.get() == null) { // This is step 2
|
||||
KeyedFilter[] keyedFiltersActual =
|
||||
topActualClassNames.stream()
|
||||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
KeyedFilter[] keyedFiltersPredicted =
|
||||
topActualClassNames.stream()
|
||||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
|
||||
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
|
||||
.field(actualField),
|
||||
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
|
||||
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
|
||||
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
|
||||
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
|
||||
.otherBucket(true)
|
||||
.otherBucketKey(OTHER_BUCKET_KEY))),
|
||||
Collections.emptyList());
|
||||
|
@ -143,18 +156,18 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
||||
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
|
||||
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
|
||||
if (topActualClassNames.get() == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
|
||||
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
|
||||
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
|
||||
}
|
||||
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
||||
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
|
||||
Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
|
||||
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
|
||||
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
|
||||
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
|
||||
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
|
||||
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
|
||||
String actualClass = bucket.getKeyAsString();
|
||||
long actualClassDocCount = bucket.getDocCount();
|
||||
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
|
||||
Filters subAgg = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS));
|
||||
List<PredictedClass> predictedClasses = new ArrayList<>();
|
||||
long otherPredictedClassDocCount = 0;
|
||||
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
|
||||
|
@ -169,18 +182,25 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
|
||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
|
||||
}
|
||||
result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
|
||||
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
|
||||
}
|
||||
}
|
||||
|
||||
private String aggName(String aggNameWithoutPrefix) {
|
||||
return aggNamePrefix + aggNameWithoutPrefix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
public Optional<Result> getResult() {
|
||||
return Optional.ofNullable(result.get());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(size);
|
||||
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||
out.writeString(aggNamePrefix);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -196,12 +216,13 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
|
||||
return Objects.equals(this.size, that.size);
|
||||
return this.size == that.size
|
||||
&& Objects.equals(this.aggNamePrefix, that.aggNamePrefix);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(size);
|
||||
return Objects.hash(size, aggNamePrefix);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
@ -335,6 +356,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
return actualClass;
|
||||
}
|
||||
|
||||
public long getActualClassDocCount() {
|
||||
return actualClassDocCount;
|
||||
}
|
||||
|
||||
public List<PredictedClass> getPredictedClasses() {
|
||||
return predictedClasses;
|
||||
}
|
||||
|
@ -411,6 +436,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
|||
return predictedClass;
|
||||
}
|
||||
|
||||
public long getCount() {
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(predictedClass);
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -77,9 +78,9 @@ public class Precision implements EvaluationMetric {
|
|||
|
||||
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||
|
||||
private String actualField;
|
||||
private List<String> topActualClassNames;
|
||||
private EvaluationMetricResult result;
|
||||
private final SetOnce<String> actualField = new SetOnce<>();
|
||||
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
|
||||
private final SetOnce<Result> result = new SetOnce<>();
|
||||
|
||||
public Precision() {}
|
||||
|
||||
|
@ -98,8 +99,8 @@ public class Precision implements EvaluationMetric {
|
|||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField = actualField;
|
||||
if (topActualClassNames == null) { // This is step 1
|
||||
this.actualField.trySet(actualField);
|
||||
if (topActualClassNames.get() == null) { // This is step 1
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
|
||||
|
@ -108,9 +109,9 @@ public class Precision implements EvaluationMetric {
|
|||
.size(MAX_CLASSES_CARDINALITY)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
if (result == null) { // This is step 2
|
||||
if (result.get() == null) { // This is step 2
|
||||
KeyedFilter[] keyedFiltersPredicted =
|
||||
topActualClassNames.stream()
|
||||
topActualClassNames.get().stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
Script script = buildScript(actualField, predictedField);
|
||||
|
@ -128,18 +129,18 @@ public class Precision implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
|
||||
if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
|
||||
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
|
||||
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
|
||||
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
||||
// We cannot calculate average precision accurately, so we fail.
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
|
||||
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get());
|
||||
}
|
||||
topActualClassNames =
|
||||
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
|
||||
topActualClassNames.set(
|
||||
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
|
||||
}
|
||||
if (result == null &&
|
||||
if (result.get() == null &&
|
||||
aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
|
||||
aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||
Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
|
||||
|
@ -153,13 +154,13 @@ public class Precision implements EvaluationMetric {
|
|||
classes.add(new PerClassResult(className, precision));
|
||||
}
|
||||
}
|
||||
result = new Result(classes, avgPrecisionAgg.value());
|
||||
result.set(new Result(classes, avgPrecisionAgg.value()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
public Optional<Result> getResult() {
|
||||
return Optional.ofNullable(result.get());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
|
@ -71,8 +72,8 @@ public class Recall implements EvaluationMetric {
|
|||
|
||||
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||
|
||||
private String actualField;
|
||||
private EvaluationMetricResult result;
|
||||
private final SetOnce<String> actualField = new SetOnce<>();
|
||||
private final SetOnce<Result> result = new SetOnce<>();
|
||||
|
||||
public Recall() {}
|
||||
|
||||
|
@ -91,8 +92,8 @@ public class Recall implements EvaluationMetric {
|
|||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField = actualField;
|
||||
if (result != null) {
|
||||
this.actualField.trySet(actualField);
|
||||
if (result.get() != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
Script script = buildScript(actualField, predictedField);
|
||||
|
@ -110,7 +111,7 @@ public class Recall implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (result == null &&
|
||||
if (result.get() == null &&
|
||||
aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
|
||||
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
|
||||
|
@ -118,7 +119,7 @@ public class Recall implements EvaluationMetric {
|
|||
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
||||
// We cannot calculate average recall accurately, so we fail.
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
|
||||
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get());
|
||||
}
|
||||
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
|
||||
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
|
||||
|
@ -127,13 +128,13 @@ public class Recall implements EvaluationMetric {
|
|||
NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
|
||||
classes.add(new PerClassResult(className, recallAgg.value()));
|
||||
}
|
||||
result = new Result(classes, avgRecallAgg.value());
|
||||
result.set(new Result(classes, avgRecallAgg.value()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
public Optional<Result> getResult() {
|
||||
return Optional.ofNullable(result.get());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -8,9 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
|||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -22,13 +22,13 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result>
|
|||
public static Result createRandom() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
|
||||
classes.add(new PerClassResult(classNames.get(i), accuracy));
|
||||
}
|
||||
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(actualClasses, overallAccuracy);
|
||||
return new Result(classes, overallAccuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -5,17 +5,27 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||
|
@ -46,52 +56,114 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
|||
|
||||
public void testProcess() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms("classification_classes"),
|
||||
mockSingleValue("classification_overall_accuracy", 0.8123),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
mockTerms(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket(
|
||||
"dog",
|
||||
30,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockFiltersBucket(
|
||||
"cat",
|
||||
70,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
|
||||
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||
|
||||
Accuracy accuracy = new Accuracy();
|
||||
accuracy.process(aggs);
|
||||
|
||||
assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123)));
|
||||
assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
|
||||
Result result = accuracy.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
result.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new PerClassResult("dog", 0.5),
|
||||
new PerClassResult("cat", 0.5))));
|
||||
assertThat(result.getOverallAccuracy(), equalTo(0.5));
|
||||
}
|
||||
|
||||
public void testProcess_GivenMissingAgg() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms("classification_classes"),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockSingleValue("classification_overall_accuracy", 0.8123),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
public void testProcess_GivenCardinalityTooHigh() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket(
|
||||
"dog",
|
||||
30,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockFiltersBucket(
|
||||
"cat",
|
||||
70,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
|
||||
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||
|
||||
Accuracy accuracy = new Accuracy();
|
||||
accuracy.aggs("foo", "bar");
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||
}
|
||||
|
||||
public void testProcess_GivenAggOfWrongType() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms("classification_classes"),
|
||||
mockTerms("classification_overall_accuracy")
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockSingleValue("classification_classes", 1.0),
|
||||
mockSingleValue("classification_overall_accuracy", 0.8123)
|
||||
));
|
||||
Accuracy accuracy = new Accuracy();
|
||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
||||
}
|
||||
public void testComputePerClassAccuracy() {
|
||||
assertThat(
|
||||
Accuracy.computePerClassAccuracy(
|
||||
new MulticlassConfusionMatrix.Result(
|
||||
Arrays.asList(
|
||||
new MulticlassConfusionMatrix.ActualClass("A", 14, Arrays.asList(
|
||||
new MulticlassConfusionMatrix.PredictedClass("A", 1),
|
||||
new MulticlassConfusionMatrix.PredictedClass("B", 6),
|
||||
new MulticlassConfusionMatrix.PredictedClass("C", 4)
|
||||
), 3L),
|
||||
new MulticlassConfusionMatrix.ActualClass("B", 20, Arrays.asList(
|
||||
new MulticlassConfusionMatrix.PredictedClass("A", 5),
|
||||
new MulticlassConfusionMatrix.PredictedClass("B", 3),
|
||||
new MulticlassConfusionMatrix.PredictedClass("C", 9)
|
||||
), 3L),
|
||||
new MulticlassConfusionMatrix.ActualClass("C", 17, Arrays.asList(
|
||||
new MulticlassConfusionMatrix.PredictedClass("A", 8),
|
||||
new MulticlassConfusionMatrix.PredictedClass("B", 2),
|
||||
new MulticlassConfusionMatrix.PredictedClass("C", 7)
|
||||
), 0L)),
|
||||
0)),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("A", 25.0 / 51), // 13 false positives, 13 false negatives
|
||||
new Accuracy.PerClassResult("B", 26.0 / 51), // 8 false positives, 17 false negatives
|
||||
new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives
|
||||
);
|
||||
}
|
||||
|
||||
public void testComputePerClassAccuracy_OtherActualClassCountIsNonZero() {
|
||||
expectThrows(
|
||||
AssertionError.class,
|
||||
() -> Accuracy.computePerClassAccuracy(new MulticlassConfusionMatrix.Result(Collections.emptyList(), 1)));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
|||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
@ -56,20 +57,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
|
||||
public static MulticlassConfusionMatrix createRandom() {
|
||||
Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
|
||||
return new MulticlassConfusionMatrix(size);
|
||||
return new MulticlassConfusionMatrix(size, null);
|
||||
}
|
||||
|
||||
public void testConstructor_SizeValidationFailures() {
|
||||
{
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1));
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1, null));
|
||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||
}
|
||||
{
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0));
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0, null));
|
||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||
}
|
||||
{
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001));
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001, null));
|
||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||
}
|
||||
}
|
||||
|
@ -84,36 +88,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
public void testEvaluate() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
"multiclass_confusion_matrix_step_1_by_actual_class",
|
||||
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
0L),
|
||||
mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket(
|
||||
"dog",
|
||||
30,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockFiltersBucket(
|
||||
"cat",
|
||||
70,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
|
||||
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
Result result = confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
result.getConfusionMatrix(),
|
||||
equalTo(
|
||||
|
@ -126,36 +130,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
public void testEvaluate_OtherClassesCountGreaterThanZero() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
"multiclass_confusion_matrix_step_1_by_actual_class",
|
||||
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket(
|
||||
"dog",
|
||||
30,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockFiltersBucket(
|
||||
"cat",
|
||||
85,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
|
||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
|
||||
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
Result result = confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
result.getConfusionMatrix(),
|
||||
equalTo(
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
|
|||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||
|
@ -22,6 +23,8 @@ import org.junit.Before;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
@ -53,10 +56,28 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
assertThat(
|
||||
evaluateDataFrameResponse.getMetrics().get(0).getMetricName(),
|
||||
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
|
||||
contains(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testEvaluate_AllMetrics() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(
|
||||
ANIMAL_NAME_FIELD,
|
||||
ANIMAL_NAME_PREDICTION_FIELD,
|
||||
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
|
||||
contains(
|
||||
Accuracy.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
Precision.NAME.getPreferredName(),
|
||||
Recall.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testEvaluate_Accuracy_KeywordField() {
|
||||
|
@ -70,14 +91,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
accuracyResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.ActualClass("ant", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("cat", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("dog", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("fox", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("mouse", 15, 1.0 / 15))));
|
||||
new Accuracy.PerClassResult("ant", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("cat", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("dog", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("fox", 47.0 / 75),
|
||||
new Accuracy.PerClassResult("mouse", 47.0 / 75))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
|
@ -92,13 +113,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
equalTo(Arrays.asList(
|
||||
new Accuracy.ActualClass("1", 15, 1.0 / 15),
|
||||
new Accuracy.ActualClass("2", 15, 2.0 / 15),
|
||||
new Accuracy.ActualClass("3", 15, 3.0 / 15),
|
||||
new Accuracy.ActualClass("4", 15, 4.0 / 15),
|
||||
new Accuracy.ActualClass("5", 15, 5.0 / 15))));
|
||||
accuracyResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("1", 57.0 / 75),
|
||||
new Accuracy.PerClassResult("2", 54.0 / 75),
|
||||
new Accuracy.PerClassResult("3", 51.0 / 75),
|
||||
new Accuracy.PerClassResult("4", 48.0 / 75),
|
||||
new Accuracy.PerClassResult("5", 45.0 / 75))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
|
||||
}
|
||||
|
||||
|
@ -113,10 +135,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
accuracyResult.getActualClasses(),
|
||||
equalTo(Arrays.asList(
|
||||
new Accuracy.ActualClass("true", 45, 27.0 / 45),
|
||||
new Accuracy.ActualClass("false", 30, 18.0 / 30))));
|
||||
accuracyResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Accuracy.PerClassResult("false", 18.0 / 30),
|
||||
new Accuracy.PerClassResult("true", 27.0 / 45))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
|
||||
}
|
||||
|
||||
|
@ -252,7 +275,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
|
|
@ -466,12 +466,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
{ // Accuracy
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||
List<Accuracy.ActualClass> actualClasses = accuracyResult.getActualClasses();
|
||||
assertThat(
|
||||
actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()),
|
||||
equalTo(dependentVariableValuesAsStrings));
|
||||
actualClasses.forEach(
|
||||
actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
|
||||
for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) {
|
||||
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
|
||||
assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
|
||||
}
|
||||
}
|
||||
|
||||
{ // MulticlassConfusionMatrix
|
||||
|
|
|
@ -620,16 +620,13 @@ setup:
|
|||
|
||||
- match:
|
||||
classification.accuracy:
|
||||
actual_classes:
|
||||
- actual_class: "cat"
|
||||
actual_class_doc_count: 3
|
||||
accuracy: 0.6666666666666666 # 2 out of 3
|
||||
- actual_class: "dog"
|
||||
actual_class_doc_count: 3
|
||||
accuracy: 0.6666666666666666 # 2 out of 3
|
||||
- actual_class: "mouse"
|
||||
actual_class_doc_count: 2
|
||||
accuracy: 0.5 # 1 out of 2
|
||||
classes:
|
||||
- class_name: "cat"
|
||||
accuracy: 0.625 # 5 out of 8
|
||||
- class_name: "dog"
|
||||
accuracy: 0.75 # 6 out of 8
|
||||
- class_name: "mouse"
|
||||
accuracy: 0.875 # 7 out of 8
|
||||
overall_accuracy: 0.625 # 5 out of 8
|
||||
---
|
||||
"Test classification precision":
|
||||
|
|
Loading…
Reference in New Issue