parent
14d95aae46
commit
3e3a93002f
|
@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ToXContent;
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
|
@ -35,10 +36,25 @@ import java.util.Objects;
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link AccuracyMetric} is a metric that answers the question:
|
* {@link AccuracyMetric} is a metric that answers the following two questions:
|
||||||
* "What fraction of examples have been classified correctly by the classifier?"
|
|
||||||
*
|
*
|
||||||
* equation: accuracy = 1/n * Σ(y == y´)
|
* 1. What is the fraction of documents for which predicted class equals the actual class?
|
||||||
|
*
|
||||||
|
* equation: overall_accuracy = 1/n * Σ(y == y')
|
||||||
|
* where: n = total number of documents
|
||||||
|
* y = document's actual class
|
||||||
|
* y' = document's predicted class
|
||||||
|
*
|
||||||
|
* 2. For any given class X, what is the fraction of documents for which either
|
||||||
|
* a) both actual and predicted class are equal to X (true positives)
|
||||||
|
* or
|
||||||
|
* b) both actual and predicted class are not equal to X (true negatives)
|
||||||
|
*
|
||||||
|
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
|
||||||
|
* where: X = class being examined
|
||||||
|
* n = total number of documents
|
||||||
|
* TP(X) = number of true positives wrt X
|
||||||
|
* TN(X) = number of true negatives wrt X
|
||||||
*/
|
*/
|
||||||
public class AccuracyMetric implements EvaluationMetric {
|
public class AccuracyMetric implements EvaluationMetric {
|
||||||
|
|
||||||
|
@ -78,15 +94,15 @@ public class AccuracyMetric implements EvaluationMetric {
|
||||||
|
|
||||||
public static class Result implements EvaluationMetric.Result {
|
public static class Result implements EvaluationMetric.Result {
|
||||||
|
|
||||||
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
|
private static final ParseField CLASSES = new ParseField("classes");
|
||||||
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
|
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
|
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||||
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,13 +110,13 @@ public class AccuracyMetric implements EvaluationMetric {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** List of actual classes. */
|
/** List of per-class results. */
|
||||||
private final List<ActualClass> actualClasses;
|
private final List<PerClassResult> classes;
|
||||||
/** Fraction of documents predicted correctly. */
|
/** Fraction of documents for which predicted class equals the actual class. */
|
||||||
private final double overallAccuracy;
|
private final double overallAccuracy;
|
||||||
|
|
||||||
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
|
public Result(List<PerClassResult> classes, double overallAccuracy) {
|
||||||
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
|
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
|
||||||
this.overallAccuracy = overallAccuracy;
|
this.overallAccuracy = overallAccuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,8 +125,8 @@ public class AccuracyMetric implements EvaluationMetric {
|
||||||
return NAME;
|
return NAME;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<ActualClass> getActualClasses() {
|
public List<PerClassResult> getClasses() {
|
||||||
return actualClasses;
|
return classes;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getOverallAccuracy() {
|
public double getOverallAccuracy() {
|
||||||
|
@ -120,7 +136,7 @@ public class AccuracyMetric implements EvaluationMetric {
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
|
builder.field(CLASSES.getPreferredName(), classes);
|
||||||
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -131,52 +147,42 @@ public class AccuracyMetric implements EvaluationMetric {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
Result that = (Result) o;
|
Result that = (Result) o;
|
||||||
return Objects.equals(this.actualClasses, that.actualClasses)
|
return Objects.equals(this.classes, that.classes)
|
||||||
&& this.overallAccuracy == that.overallAccuracy;
|
&& this.overallAccuracy == that.overallAccuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(actualClasses, overallAccuracy);
|
return Objects.hash(classes, overallAccuracy);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class ActualClass implements ToXContentObject {
|
public static class PerClassResult implements ToXContentObject {
|
||||||
|
|
||||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
|
||||||
private static final ParseField ACCURACY = new ParseField("accuracy");
|
private static final ParseField ACCURACY = new ParseField("accuracy");
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||||
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
|
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||||
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
|
||||||
PARSER.declareDouble(constructorArg(), ACCURACY);
|
PARSER.declareDouble(constructorArg(), ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Name of the actual class. */
|
/** Name of the class. */
|
||||||
private final String actualClass;
|
private final String className;
|
||||||
/** Number of documents (examples) belonging to the {@code actualClass} class. */
|
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
|
||||||
private final long actualClassDocCount;
|
|
||||||
/** Fraction of documents belonging to the {@code actualClass} class predicted correctly. */
|
|
||||||
private final double accuracy;
|
private final double accuracy;
|
||||||
|
|
||||||
public ActualClass(
|
public PerClassResult(String className, double accuracy) {
|
||||||
String actualClass, long actualClassDocCount, double accuracy) {
|
this.className = Objects.requireNonNull(className);
|
||||||
this.actualClass = Objects.requireNonNull(actualClass);
|
|
||||||
this.actualClassDocCount = actualClassDocCount;
|
|
||||||
this.accuracy = accuracy;
|
this.accuracy = accuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getActualClass() {
|
public String getClassName() {
|
||||||
return actualClass;
|
return className;
|
||||||
}
|
|
||||||
|
|
||||||
public long getActualClassDocCount() {
|
|
||||||
return actualClassDocCount;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getAccuracy() {
|
public double getAccuracy() {
|
||||||
|
@ -186,8 +192,7 @@ public class AccuracyMetric implements EvaluationMetric {
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
|
||||||
builder.field(ACCURACY.getPreferredName(), accuracy);
|
builder.field(ACCURACY.getPreferredName(), accuracy);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -197,15 +202,19 @@ public class AccuracyMetric implements EvaluationMetric {
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
ActualClass that = (ActualClass) o;
|
PerClassResult that = (PerClassResult) o;
|
||||||
return Objects.equals(this.actualClass, that.actualClass)
|
return Objects.equals(this.className, that.className)
|
||||||
&& this.actualClassDocCount == that.actualClassDocCount
|
|
||||||
&& this.accuracy == that.accuracy;
|
&& this.accuracy == that.accuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(actualClass, actualClassDocCount, accuracy);
|
return Objects.hash(className, accuracy);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Strings.toString(this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1849,15 +1849,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
|
AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
|
||||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||||
assertThat(
|
assertThat(
|
||||||
accuracyResult.getActualClasses(),
|
accuracyResult.getClasses(),
|
||||||
equalTo(
|
equalTo(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
// 9 out of 10 examples were classified correctly
|
||||||
new AccuracyMetric.ActualClass("cat", 5, 0.6),
|
new AccuracyMetric.PerClassResult("ant", 0.9),
|
||||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
// 6 out of 10 examples were classified correctly
|
||||||
new AccuracyMetric.ActualClass("dog", 4, 0.75),
|
new AccuracyMetric.PerClassResult("cat", 0.6),
|
||||||
// no examples labeled as "ant" were classified correctly
|
// 8 out of 10 examples were classified correctly
|
||||||
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
|
new AccuracyMetric.PerClassResult("dog", 0.8))));
|
||||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
|
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
|
||||||
}
|
}
|
||||||
{ // Precision
|
{ // Precision
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
@ -41,13 +41,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result>
|
||||||
public static Result randomResult() {
|
public static Result randomResult() {
|
||||||
int numClasses = randomIntBetween(2, 100);
|
int numClasses = randomIntBetween(2, 100);
|
||||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||||
for (int i = 0; i < numClasses; i++) {
|
for (int i = 0; i < numClasses; i++) {
|
||||||
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||||
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
|
classes.add(new PerClassResult(classNames.get(i), accuracy));
|
||||||
}
|
}
|
||||||
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||||
return new Result(actualClasses, overallAccuracy);
|
return new Result(classes, overallAccuracy);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
||||||
* Gets the evaluation result for this metric.
|
* Gets the evaluation result for this metric.
|
||||||
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
|
* @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
|
||||||
*/
|
*/
|
||||||
Optional<EvaluationMetricResult> getResult();
|
Optional<? extends EvaluationMetricResult> getResult();
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.SetOnce;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
@ -29,7 +29,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.text.MessageFormat;
|
import java.text.MessageFormat;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
|
@ -40,22 +39,36 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link Accuracy} is a metric that answers the question:
|
* {@link Accuracy} is a metric that answers the following two questions:
|
||||||
* "What fraction of examples have been classified correctly by the classifier?"
|
|
||||||
*
|
*
|
||||||
* equation: accuracy = 1/n * Σ(y == y´)
|
* 1. What is the fraction of documents for which predicted class equals the actual class?
|
||||||
|
*
|
||||||
|
* equation: overall_accuracy = 1/n * Σ(y == y')
|
||||||
|
* where: n = total number of documents
|
||||||
|
* y = document's actual class
|
||||||
|
* y' = document's predicted class
|
||||||
|
*
|
||||||
|
* 2. For any given class X, what is the fraction of documents for which either
|
||||||
|
* a) both actual and predicted class are equal to X (true positives)
|
||||||
|
* or
|
||||||
|
* b) both actual and predicted class are not equal to X (true negatives)
|
||||||
|
*
|
||||||
|
* equation: accuracy(X) = 1/n * (TP(X) + TN(X))
|
||||||
|
* where: X = class being examined
|
||||||
|
* n = total number of documents
|
||||||
|
* TP(X) = number of true positives wrt X
|
||||||
|
* TN(X) = number of true negatives wrt X
|
||||||
*/
|
*/
|
||||||
public class Accuracy implements EvaluationMetric {
|
public class Accuracy implements EvaluationMetric {
|
||||||
|
|
||||||
public static final ParseField NAME = new ParseField("accuracy");
|
public static final ParseField NAME = new ParseField("accuracy");
|
||||||
|
|
||||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
||||||
private static final String CLASSES_AGG_NAME = "classification_classes";
|
|
||||||
private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
|
|
||||||
private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
|
||||||
|
|
||||||
private static String buildScript(Object...args) {
|
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
|
||||||
|
private static Script buildScript(Object...args) {
|
||||||
|
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
|
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
|
||||||
|
@ -64,11 +77,20 @@ public class Accuracy implements EvaluationMetric {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private EvaluationMetricResult result;
|
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||||
|
|
||||||
public Accuracy() {}
|
private final MulticlassConfusionMatrix matrix;
|
||||||
|
private final SetOnce<String> actualField = new SetOnce<>();
|
||||||
|
private final SetOnce<Double> overallAccuracy = new SetOnce<>();
|
||||||
|
private final SetOnce<Result> result = new SetOnce<>();
|
||||||
|
|
||||||
public Accuracy(StreamInput in) throws IOException {}
|
public Accuracy() {
|
||||||
|
this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
|
||||||
|
}
|
||||||
|
|
||||||
|
public Accuracy(StreamInput in) throws IOException {
|
||||||
|
this.matrix = new MulticlassConfusionMatrix(in);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getWriteableName() {
|
public String getWriteableName() {
|
||||||
|
@ -82,43 +104,79 @@ public class Accuracy implements EvaluationMetric {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||||
if (result != null) {
|
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
this.actualField.trySet(actualField);
|
||||||
|
List<AggregationBuilder> aggs = new ArrayList<>();
|
||||||
|
List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
|
||||||
|
if (overallAccuracy.get() == null) {
|
||||||
|
aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
|
||||||
}
|
}
|
||||||
Script accuracyScript = new Script(buildScript(actualField, predictedField));
|
if (result.get() == null) {
|
||||||
return Tuple.tuple(
|
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
|
||||||
Arrays.asList(
|
aggs.addAll(matrixAggs.v1());
|
||||||
AggregationBuilders.terms(CLASSES_AGG_NAME)
|
pipelineAggs.addAll(matrixAggs.v2());
|
||||||
.field(actualField)
|
}
|
||||||
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
|
return Tuple.tuple(aggs, pipelineAggs);
|
||||||
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
|
|
||||||
Collections.emptyList());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(Aggregations aggs) {
|
public void process(Aggregations aggs) {
|
||||||
if (result != null) {
|
if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||||
return;
|
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
|
||||||
|
overallAccuracy.set(overallAccuracyAgg.value());
|
||||||
}
|
}
|
||||||
Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
|
matrix.process(aggs);
|
||||||
NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
|
if (result.get() == null && matrix.getResult().isPresent()) {
|
||||||
List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
|
if (matrix.getResult().get().getOtherActualClassCount() > 0) {
|
||||||
for (Terms.Bucket bucket : classesAgg.getBuckets()) {
|
// This means there were more than {@code MAX_CLASSES_CARDINALITY} distinct actual classes (buckets).
|
||||||
String actualClass = bucket.getKeyAsString();
|
// We cannot calculate per-class accuracy accurately, so we fail.
|
||||||
long actualClassDocCount = bucket.getDocCount();
|
throw ExceptionsHelper.badRequestException(
|
||||||
NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
|
"Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get());
|
||||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
|
}
|
||||||
|
result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get()));
|
||||||
}
|
}
|
||||||
result = new Result(actualClasses, overallAccuracyAgg.value());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<EvaluationMetricResult> getResult() {
|
public Optional<Result> getResult() {
|
||||||
return Optional.ofNullable(result);
|
return Optional.ofNullable(result.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the per-class accuracy results based on multiclass confusion matrix's result.
|
||||||
|
* Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension.
|
||||||
|
* This method is visible for testing only.
|
||||||
|
*/
|
||||||
|
static List<PerClassResult> computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) {
|
||||||
|
assert matrixResult.getOtherActualClassCount() == 0;
|
||||||
|
// Number of actual classes taken into account
|
||||||
|
int n = matrixResult.getConfusionMatrix().size();
|
||||||
|
// Total number of documents taken into account
|
||||||
|
long totalDocCount =
|
||||||
|
matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum();
|
||||||
|
List<PerClassResult> classes = new ArrayList<>(n);
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
String className = matrixResult.getConfusionMatrix().get(i).getActualClass();
|
||||||
|
// Start with the assumption that all the docs were predicted correctly.
|
||||||
|
long correctDocCount = totalDocCount;
|
||||||
|
for (int j = 0; j < n; ++j) {
|
||||||
|
if (i != j) {
|
||||||
|
// Subtract errors (false negatives)
|
||||||
|
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount();
|
||||||
|
// Subtract errors (false positives)
|
||||||
|
correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix
|
||||||
|
correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount();
|
||||||
|
classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount));
|
||||||
|
}
|
||||||
|
return classes;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
matrix.writeTo(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -132,25 +190,26 @@ public class Accuracy implements EvaluationMetric {
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
return true;
|
Accuracy that = (Accuracy) o;
|
||||||
|
return Objects.equals(this.matrix, that.matrix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hashCode(NAME.getPreferredName());
|
return Objects.hash(matrix);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Result implements EvaluationMetricResult {
|
public static class Result implements EvaluationMetricResult {
|
||||||
|
|
||||||
private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
|
private static final ParseField CLASSES = new ParseField("classes");
|
||||||
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||||
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
|
new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
|
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||||
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,18 +217,18 @@ public class Accuracy implements EvaluationMetric {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** List of actual classes. */
|
/** List of per-class results. */
|
||||||
private final List<ActualClass> actualClasses;
|
private final List<PerClassResult> classes;
|
||||||
/** Fraction of documents predicted correctly. */
|
/** Fraction of documents for which predicted class equals the actual class. */
|
||||||
private final double overallAccuracy;
|
private final double overallAccuracy;
|
||||||
|
|
||||||
public Result(List<ActualClass> actualClasses, double overallAccuracy) {
|
public Result(List<PerClassResult> classes, double overallAccuracy) {
|
||||||
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
|
this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
|
||||||
this.overallAccuracy = overallAccuracy;
|
this.overallAccuracy = overallAccuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Result(StreamInput in) throws IOException {
|
public Result(StreamInput in) throws IOException {
|
||||||
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
|
this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
|
||||||
this.overallAccuracy = in.readDouble();
|
this.overallAccuracy = in.readDouble();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,8 +242,8 @@ public class Accuracy implements EvaluationMetric {
|
||||||
return NAME.getPreferredName();
|
return NAME.getPreferredName();
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<ActualClass> getActualClasses() {
|
public List<PerClassResult> getClasses() {
|
||||||
return actualClasses;
|
return classes;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getOverallAccuracy() {
|
public double getOverallAccuracy() {
|
||||||
|
@ -193,14 +252,14 @@ public class Accuracy implements EvaluationMetric {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeList(actualClasses);
|
out.writeList(classes);
|
||||||
out.writeDouble(overallAccuracy);
|
out.writeDouble(overallAccuracy);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
|
builder.field(CLASSES.getPreferredName(), classes);
|
||||||
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -211,54 +270,47 @@ public class Accuracy implements EvaluationMetric {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
Result that = (Result) o;
|
Result that = (Result) o;
|
||||||
return Objects.equals(this.actualClasses, that.actualClasses)
|
return Objects.equals(this.classes, that.classes)
|
||||||
&& this.overallAccuracy == that.overallAccuracy;
|
&& this.overallAccuracy == that.overallAccuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(actualClasses, overallAccuracy);
|
return Objects.hash(classes, overallAccuracy);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class ActualClass implements ToXContentObject, Writeable {
|
public static class PerClassResult implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
|
||||||
private static final ParseField ACCURACY = new ParseField("accuracy");
|
private static final ParseField ACCURACY = new ParseField("accuracy");
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||||
new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
|
new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||||
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
|
||||||
PARSER.declareDouble(constructorArg(), ACCURACY);
|
PARSER.declareDouble(constructorArg(), ACCURACY);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Name of the actual class. */
|
/** Name of the class. */
|
||||||
private final String actualClass;
|
private final String className;
|
||||||
/** Number of documents (examples) belonging to the {code actualClass} class. */
|
/** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
|
||||||
private final long actualClassDocCount;
|
|
||||||
/** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
|
|
||||||
private final double accuracy;
|
private final double accuracy;
|
||||||
|
|
||||||
public ActualClass(
|
public PerClassResult(String className, double accuracy) {
|
||||||
String actualClass, long actualClassDocCount, double accuracy) {
|
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
|
||||||
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
|
|
||||||
this.actualClassDocCount = actualClassDocCount;
|
|
||||||
this.accuracy = accuracy;
|
this.accuracy = accuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ActualClass(StreamInput in) throws IOException {
|
public PerClassResult(StreamInput in) throws IOException {
|
||||||
this.actualClass = in.readString();
|
this.className = in.readString();
|
||||||
this.actualClassDocCount = in.readVLong();
|
|
||||||
this.accuracy = in.readDouble();
|
this.accuracy = in.readDouble();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getActualClass() {
|
public String getClassName() {
|
||||||
return actualClass;
|
return className;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getAccuracy() {
|
public double getAccuracy() {
|
||||||
|
@ -267,16 +319,14 @@ public class Accuracy implements EvaluationMetric {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeString(actualClass);
|
out.writeString(className);
|
||||||
out.writeVLong(actualClassDocCount);
|
|
||||||
out.writeDouble(accuracy);
|
out.writeDouble(accuracy);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
|
||||||
builder.field(ACCURACY.getPreferredName(), accuracy);
|
builder.field(ACCURACY.getPreferredName(), accuracy);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
|
@ -286,15 +336,14 @@ public class Accuracy implements EvaluationMetric {
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
ActualClass that = (ActualClass) o;
|
PerClassResult that = (PerClassResult) o;
|
||||||
return Objects.equals(this.actualClass, that.actualClass)
|
return Objects.equals(this.className, that.className)
|
||||||
&& this.actualClassDocCount == that.actualClassDocCount
|
|
||||||
&& this.accuracy == that.accuracy;
|
&& this.accuracy == that.accuracy;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(actualClass, actualClassDocCount, accuracy);
|
return Objects.hash(className, accuracy);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,8 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.SetOnce;
|
||||||
|
import org.elasticsearch.Version;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
|
@ -53,13 +55,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
|
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
|
||||||
|
|
||||||
public static final ParseField SIZE = new ParseField("size");
|
public static final ParseField SIZE = new ParseField("size");
|
||||||
|
public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix");
|
||||||
|
|
||||||
private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
|
private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
|
||||||
|
|
||||||
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
|
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
|
||||||
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
|
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
|
||||||
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
|
new ConstructingObjectParser<>(
|
||||||
|
NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1]));
|
||||||
parser.declareInt(optionalConstructorArg(), SIZE);
|
parser.declareInt(optionalConstructorArg(), SIZE);
|
||||||
|
parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX);
|
||||||
return parser;
|
return parser;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,31 +72,39 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
||||||
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
||||||
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
||||||
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
|
static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
|
||||||
private static final String OTHER_BUCKET_KEY = "_other_";
|
private static final String OTHER_BUCKET_KEY = "_other_";
|
||||||
|
private static final String DEFAULT_AGG_NAME_PREFIX = "";
|
||||||
private static final int DEFAULT_SIZE = 10;
|
private static final int DEFAULT_SIZE = 10;
|
||||||
private static final int MAX_SIZE = 1000;
|
private static final int MAX_SIZE = 1000;
|
||||||
|
|
||||||
private final int size;
|
private final int size;
|
||||||
private List<String> topActualClassNames;
|
private final String aggNamePrefix;
|
||||||
private Result result;
|
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
|
||||||
|
private final SetOnce<Result> result = new SetOnce<>();
|
||||||
|
|
||||||
public MulticlassConfusionMatrix() {
|
public MulticlassConfusionMatrix() {
|
||||||
this((Integer) null);
|
this(null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public MulticlassConfusionMatrix(@Nullable Integer size) {
|
public MulticlassConfusionMatrix(@Nullable Integer size, @Nullable String aggNamePrefix) {
|
||||||
if (size != null && (size <= 0 || size > MAX_SIZE)) {
|
if (size != null && (size <= 0 || size > MAX_SIZE)) {
|
||||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
|
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
|
||||||
}
|
}
|
||||||
this.size = size != null ? size : DEFAULT_SIZE;
|
this.size = size != null ? size : DEFAULT_SIZE;
|
||||||
|
this.aggNamePrefix = aggNamePrefix != null ? aggNamePrefix : DEFAULT_AGG_NAME_PREFIX;
|
||||||
}
|
}
|
||||||
|
|
||||||
public MulticlassConfusionMatrix(StreamInput in) throws IOException {
|
public MulticlassConfusionMatrix(StreamInput in) throws IOException {
|
||||||
this.size = in.readVInt();
|
this.size = in.readVInt();
|
||||||
|
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||||
|
this.aggNamePrefix = in.readString();
|
||||||
|
} else {
|
||||||
|
this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -110,30 +123,30 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||||
if (topActualClassNames == null) { // This is step 1
|
if (topActualClassNames.get() == null) { // This is step 1
|
||||||
return Tuple.tuple(
|
return Tuple.tuple(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
|
AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
|
||||||
.field(actualField)
|
.field(actualField)
|
||||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||||
.size(size)),
|
.size(size)),
|
||||||
Collections.emptyList());
|
Collections.emptyList());
|
||||||
}
|
}
|
||||||
if (result == null) { // This is step 2
|
if (result.get() == null) { // This is step 2
|
||||||
KeyedFilter[] keyedFiltersActual =
|
KeyedFilter[] keyedFiltersActual =
|
||||||
topActualClassNames.stream()
|
topActualClassNames.get().stream()
|
||||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
||||||
.toArray(KeyedFilter[]::new);
|
.toArray(KeyedFilter[]::new);
|
||||||
KeyedFilter[] keyedFiltersPredicted =
|
KeyedFilter[] keyedFiltersPredicted =
|
||||||
topActualClassNames.stream()
|
topActualClassNames.get().stream()
|
||||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||||
.toArray(KeyedFilter[]::new);
|
.toArray(KeyedFilter[]::new);
|
||||||
return Tuple.tuple(
|
return Tuple.tuple(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
|
AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
|
||||||
.field(actualField),
|
.field(actualField),
|
||||||
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
|
AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
|
||||||
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
|
.subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
|
||||||
.otherBucket(true)
|
.otherBucket(true)
|
||||||
.otherBucketKey(OTHER_BUCKET_KEY))),
|
.otherBucketKey(OTHER_BUCKET_KEY))),
|
||||||
Collections.emptyList());
|
Collections.emptyList());
|
||||||
|
@ -143,18 +156,18 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(Aggregations aggs) {
|
public void process(Aggregations aggs) {
|
||||||
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
if (topActualClassNames.get() == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
|
||||||
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
|
Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
|
||||||
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
|
topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
|
||||||
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
|
Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
|
||||||
Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
|
Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
|
||||||
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
|
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
|
||||||
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
|
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
|
||||||
String actualClass = bucket.getKeyAsString();
|
String actualClass = bucket.getKeyAsString();
|
||||||
long actualClassDocCount = bucket.getDocCount();
|
long actualClassDocCount = bucket.getDocCount();
|
||||||
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
|
Filters subAgg = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS));
|
||||||
List<PredictedClass> predictedClasses = new ArrayList<>();
|
List<PredictedClass> predictedClasses = new ArrayList<>();
|
||||||
long otherPredictedClassDocCount = 0;
|
long otherPredictedClassDocCount = 0;
|
||||||
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
|
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
|
||||||
|
@ -169,18 +182,25 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
|
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
|
||||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
|
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
|
||||||
}
|
}
|
||||||
result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
|
result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private String aggName(String aggNameWithoutPrefix) {
|
||||||
|
return aggNamePrefix + aggNameWithoutPrefix;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<EvaluationMetricResult> getResult() {
|
public Optional<Result> getResult() {
|
||||||
return Optional.ofNullable(result);
|
return Optional.ofNullable(result.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeVInt(size);
|
out.writeVInt(size);
|
||||||
|
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
||||||
|
out.writeString(aggNamePrefix);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -196,12 +216,13 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
|
MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
|
||||||
return Objects.equals(this.size, that.size);
|
return this.size == that.size
|
||||||
|
&& Objects.equals(this.aggNamePrefix, that.aggNamePrefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(size);
|
return Objects.hash(size, aggNamePrefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Result implements EvaluationMetricResult {
|
public static class Result implements EvaluationMetricResult {
|
||||||
|
@ -335,6 +356,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
return actualClass;
|
return actualClass;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public long getActualClassDocCount() {
|
||||||
|
return actualClassDocCount;
|
||||||
|
}
|
||||||
|
|
||||||
public List<PredictedClass> getPredictedClasses() {
|
public List<PredictedClass> getPredictedClasses() {
|
||||||
return predictedClasses;
|
return predictedClasses;
|
||||||
}
|
}
|
||||||
|
@ -411,6 +436,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||||
return predictedClass;
|
return predictedClass;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public long getCount() {
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
out.writeString(predictedClass);
|
out.writeString(predictedClass);
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.SetOnce;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
@ -77,9 +78,9 @@ public class Precision implements EvaluationMetric {
|
||||||
|
|
||||||
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||||
|
|
||||||
private String actualField;
|
private final SetOnce<String> actualField = new SetOnce<>();
|
||||||
private List<String> topActualClassNames;
|
private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
|
||||||
private EvaluationMetricResult result;
|
private final SetOnce<Result> result = new SetOnce<>();
|
||||||
|
|
||||||
public Precision() {}
|
public Precision() {}
|
||||||
|
|
||||||
|
@ -98,8 +99,8 @@ public class Precision implements EvaluationMetric {
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||||
this.actualField = actualField;
|
this.actualField.trySet(actualField);
|
||||||
if (topActualClassNames == null) { // This is step 1
|
if (topActualClassNames.get() == null) { // This is step 1
|
||||||
return Tuple.tuple(
|
return Tuple.tuple(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
|
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
|
||||||
|
@ -108,9 +109,9 @@ public class Precision implements EvaluationMetric {
|
||||||
.size(MAX_CLASSES_CARDINALITY)),
|
.size(MAX_CLASSES_CARDINALITY)),
|
||||||
Collections.emptyList());
|
Collections.emptyList());
|
||||||
}
|
}
|
||||||
if (result == null) { // This is step 2
|
if (result.get() == null) { // This is step 2
|
||||||
KeyedFilter[] keyedFiltersPredicted =
|
KeyedFilter[] keyedFiltersPredicted =
|
||||||
topActualClassNames.stream()
|
topActualClassNames.get().stream()
|
||||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||||
.toArray(KeyedFilter[]::new);
|
.toArray(KeyedFilter[]::new);
|
||||||
Script script = buildScript(actualField, predictedField);
|
Script script = buildScript(actualField, predictedField);
|
||||||
|
@ -128,18 +129,18 @@ public class Precision implements EvaluationMetric {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(Aggregations aggs) {
|
public void process(Aggregations aggs) {
|
||||||
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
|
if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
|
||||||
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
|
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
|
||||||
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
|
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
|
||||||
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
||||||
// We cannot calculate average precision accurately, so we fail.
|
// We cannot calculate average precision accurately, so we fail.
|
||||||
throw ExceptionsHelper.badRequestException(
|
throw ExceptionsHelper.badRequestException(
|
||||||
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
|
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get());
|
||||||
}
|
}
|
||||||
topActualClassNames =
|
topActualClassNames.set(
|
||||||
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
|
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
if (result == null &&
|
if (result.get() == null &&
|
||||||
aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
|
aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
|
||||||
aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||||
Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
|
Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
|
||||||
|
@ -153,13 +154,13 @@ public class Precision implements EvaluationMetric {
|
||||||
classes.add(new PerClassResult(className, precision));
|
classes.add(new PerClassResult(className, precision));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result = new Result(classes, avgPrecisionAgg.value());
|
result.set(new Result(classes, avgPrecisionAgg.value()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<EvaluationMetricResult> getResult() {
|
public Optional<Result> getResult() {
|
||||||
return Optional.ofNullable(result);
|
return Optional.ofNullable(result.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.SetOnce;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
@ -71,8 +72,8 @@ public class Recall implements EvaluationMetric {
|
||||||
|
|
||||||
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
private static final int MAX_CLASSES_CARDINALITY = 1000;
|
||||||
|
|
||||||
private String actualField;
|
private final SetOnce<String> actualField = new SetOnce<>();
|
||||||
private EvaluationMetricResult result;
|
private final SetOnce<Result> result = new SetOnce<>();
|
||||||
|
|
||||||
public Recall() {}
|
public Recall() {}
|
||||||
|
|
||||||
|
@ -91,8 +92,8 @@ public class Recall implements EvaluationMetric {
|
||||||
@Override
|
@Override
|
||||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||||
this.actualField = actualField;
|
this.actualField.trySet(actualField);
|
||||||
if (result != null) {
|
if (result.get() != null) {
|
||||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||||
}
|
}
|
||||||
Script script = buildScript(actualField, predictedField);
|
Script script = buildScript(actualField, predictedField);
|
||||||
|
@ -110,7 +111,7 @@ public class Recall implements EvaluationMetric {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(Aggregations aggs) {
|
public void process(Aggregations aggs) {
|
||||||
if (result == null &&
|
if (result.get() == null &&
|
||||||
aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
|
aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
|
||||||
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||||
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
|
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
|
||||||
|
@ -118,7 +119,7 @@ public class Recall implements EvaluationMetric {
|
||||||
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
// This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
||||||
// We cannot calculate average recall accurately, so we fail.
|
// We cannot calculate average recall accurately, so we fail.
|
||||||
throw ExceptionsHelper.badRequestException(
|
throw ExceptionsHelper.badRequestException(
|
||||||
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
|
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get());
|
||||||
}
|
}
|
||||||
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
|
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
|
||||||
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
|
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
|
||||||
|
@ -127,13 +128,13 @@ public class Recall implements EvaluationMetric {
|
||||||
NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
|
NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
|
||||||
classes.add(new PerClassResult(className, recallAgg.value()));
|
classes.add(new PerClassResult(className, recallAgg.value()));
|
||||||
}
|
}
|
||||||
result = new Result(classes, avgRecallAgg.value());
|
result.set(new Result(classes, avgRecallAgg.value()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<EvaluationMetricResult> getResult() {
|
public Optional<Result> getResult() {
|
||||||
return Optional.ofNullable(result);
|
return Optional.ofNullable(result.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -8,9 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -22,13 +22,13 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result>
|
||||||
public static Result createRandom() {
|
public static Result createRandom() {
|
||||||
int numClasses = randomIntBetween(2, 100);
|
int numClasses = randomIntBetween(2, 100);
|
||||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||||
for (int i = 0; i < numClasses; i++) {
|
for (int i = 0; i < numClasses; i++) {
|
||||||
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
double accuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||||
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
|
classes.add(new PerClassResult(classNames.get(i), accuracy));
|
||||||
}
|
}
|
||||||
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
|
||||||
return new Result(actualClasses, overallAccuracy);
|
return new Result(classes, overallAccuracy);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -5,17 +5,27 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||||
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
|
@ -46,52 +56,114 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
|
||||||
|
|
||||||
public void testProcess() {
|
public void testProcess() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockTerms("classification_classes"),
|
mockTerms(
|
||||||
mockSingleValue("classification_overall_accuracy", 0.8123),
|
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
Arrays.asList(
|
||||||
));
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
|
100L),
|
||||||
|
mockFilters(
|
||||||
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket(
|
||||||
|
"dog",
|
||||||
|
30,
|
||||||
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||||
|
mockFiltersBucket(
|
||||||
|
"cat",
|
||||||
|
70,
|
||||||
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||||
|
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
|
||||||
|
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||||
|
|
||||||
Accuracy accuracy = new Accuracy();
|
Accuracy accuracy = new Accuracy();
|
||||||
accuracy.process(aggs);
|
accuracy.process(aggs);
|
||||||
|
|
||||||
assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(Collections.emptyList(), 0.8123)));
|
assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||||
|
|
||||||
|
Result result = accuracy.getResult().get();
|
||||||
|
assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||||
|
assertThat(
|
||||||
|
result.getClasses(),
|
||||||
|
equalTo(
|
||||||
|
Arrays.asList(
|
||||||
|
new PerClassResult("dog", 0.5),
|
||||||
|
new PerClassResult("cat", 0.5))));
|
||||||
|
assertThat(result.getOverallAccuracy(), equalTo(0.5));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testProcess_GivenMissingAgg() {
|
public void testProcess_GivenCardinalityTooHigh() {
|
||||||
{
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
mockTerms(
|
||||||
mockTerms("classification_classes"),
|
"accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
Arrays.asList(
|
||||||
));
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
Accuracy accuracy = new Accuracy();
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
100L),
|
||||||
}
|
mockFilters(
|
||||||
{
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Arrays.asList(
|
||||||
mockSingleValue("classification_overall_accuracy", 0.8123),
|
mockFiltersBucket(
|
||||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
"dog",
|
||||||
));
|
30,
|
||||||
Accuracy accuracy = new Accuracy();
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
}
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||||
|
mockFiltersBucket(
|
||||||
|
"cat",
|
||||||
|
70,
|
||||||
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
|
"accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
|
Arrays.asList(
|
||||||
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||||
|
mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
|
||||||
|
mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
|
||||||
|
|
||||||
|
Accuracy accuracy = new Accuracy();
|
||||||
|
accuracy.aggs("foo", "bar");
|
||||||
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
|
||||||
|
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testProcess_GivenAggOfWrongType() {
|
public void testComputePerClassAccuracy() {
|
||||||
{
|
assertThat(
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Accuracy.computePerClassAccuracy(
|
||||||
mockTerms("classification_classes"),
|
new MulticlassConfusionMatrix.Result(
|
||||||
mockTerms("classification_overall_accuracy")
|
Arrays.asList(
|
||||||
));
|
new MulticlassConfusionMatrix.ActualClass("A", 14, Arrays.asList(
|
||||||
Accuracy accuracy = new Accuracy();
|
new MulticlassConfusionMatrix.PredictedClass("A", 1),
|
||||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
new MulticlassConfusionMatrix.PredictedClass("B", 6),
|
||||||
}
|
new MulticlassConfusionMatrix.PredictedClass("C", 4)
|
||||||
{
|
), 3L),
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
new MulticlassConfusionMatrix.ActualClass("B", 20, Arrays.asList(
|
||||||
mockSingleValue("classification_classes", 1.0),
|
new MulticlassConfusionMatrix.PredictedClass("A", 5),
|
||||||
mockSingleValue("classification_overall_accuracy", 0.8123)
|
new MulticlassConfusionMatrix.PredictedClass("B", 3),
|
||||||
));
|
new MulticlassConfusionMatrix.PredictedClass("C", 9)
|
||||||
Accuracy accuracy = new Accuracy();
|
), 3L),
|
||||||
expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
|
new MulticlassConfusionMatrix.ActualClass("C", 17, Arrays.asList(
|
||||||
}
|
new MulticlassConfusionMatrix.PredictedClass("A", 8),
|
||||||
|
new MulticlassConfusionMatrix.PredictedClass("B", 2),
|
||||||
|
new MulticlassConfusionMatrix.PredictedClass("C", 7)
|
||||||
|
), 0L)),
|
||||||
|
0)),
|
||||||
|
equalTo(
|
||||||
|
Arrays.asList(
|
||||||
|
new Accuracy.PerClassResult("A", 25.0 / 51), // 13 false positives, 13 false negatives
|
||||||
|
new Accuracy.PerClassResult("B", 26.0 / 51), // 8 false positives, 17 false negatives
|
||||||
|
new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testComputePerClassAccuracy_OtherActualClassCountIsNonZero() {
|
||||||
|
expectThrows(
|
||||||
|
AssertionError.class,
|
||||||
|
() -> Accuracy.computePerClassAccuracy(new MulticlassConfusionMatrix.Result(Collections.emptyList(), 1)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -56,20 +57,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
|
|
||||||
public static MulticlassConfusionMatrix createRandom() {
|
public static MulticlassConfusionMatrix createRandom() {
|
||||||
Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
|
Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
|
||||||
return new MulticlassConfusionMatrix(size);
|
return new MulticlassConfusionMatrix(size, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_SizeValidationFailures() {
|
public void testConstructor_SizeValidationFailures() {
|
||||||
{
|
{
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1));
|
ElasticsearchStatusException e =
|
||||||
|
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1, null));
|
||||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0));
|
ElasticsearchStatusException e =
|
||||||
|
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0, null));
|
||||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001));
|
ElasticsearchStatusException e =
|
||||||
|
expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001, null));
|
||||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,36 +88,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockTerms(
|
mockTerms(
|
||||||
"multiclass_confusion_matrix_step_1_by_actual_class",
|
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
0L),
|
0L),
|
||||||
mockFilters(
|
mockFilters(
|
||||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket(
|
mockFiltersBucket(
|
||||||
"dog",
|
"dog",
|
||||||
30,
|
30,
|
||||||
new Aggregations(Arrays.asList(mockFilters(
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||||
mockFiltersBucket(
|
mockFiltersBucket(
|
||||||
"cat",
|
"cat",
|
||||||
70,
|
70,
|
||||||
new Aggregations(Arrays.asList(mockFilters(
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
|
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
|
||||||
|
|
||||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||||
confusionMatrix.process(aggs);
|
confusionMatrix.process(aggs);
|
||||||
|
|
||||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
Result result = confusionMatrix.getResult().get();
|
||||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||||
assertThat(
|
assertThat(
|
||||||
result.getConfusionMatrix(),
|
result.getConfusionMatrix(),
|
||||||
equalTo(
|
equalTo(
|
||||||
|
@ -126,36 +130,36 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||||
public void testEvaluate_OtherClassesCountGreaterThanZero() {
|
public void testEvaluate_OtherClassesCountGreaterThanZero() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockTerms(
|
mockTerms(
|
||||||
"multiclass_confusion_matrix_step_1_by_actual_class",
|
MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||||
100L),
|
100L),
|
||||||
mockFilters(
|
mockFilters(
|
||||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket(
|
mockFiltersBucket(
|
||||||
"dog",
|
"dog",
|
||||||
30,
|
30,
|
||||||
new Aggregations(Arrays.asList(mockFilters(
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||||
mockFiltersBucket(
|
mockFiltersBucket(
|
||||||
"cat",
|
"cat",
|
||||||
85,
|
85,
|
||||||
new Aggregations(Arrays.asList(mockFilters(
|
new Aggregations(Arrays.asList(mockFilters(
|
||||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
|
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
|
||||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
|
mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
|
||||||
|
|
||||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
|
||||||
confusionMatrix.process(aggs);
|
confusionMatrix.process(aggs);
|
||||||
|
|
||||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
Result result = confusionMatrix.getResult().get();
|
||||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||||
assertThat(
|
assertThat(
|
||||||
result.getConfusionMatrix(),
|
result.getConfusionMatrix(),
|
||||||
equalTo(
|
equalTo(
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||||
|
@ -22,6 +23,8 @@ import org.junit.Before;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
|
import static org.hamcrest.Matchers.contains;
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.hasSize;
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
|
@ -53,10 +56,28 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
|
evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
|
||||||
|
|
||||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
|
||||||
assertThat(
|
assertThat(
|
||||||
evaluateDataFrameResponse.getMetrics().get(0).getMetricName(),
|
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
|
||||||
equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
contains(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testEvaluate_AllMetrics() {
|
||||||
|
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||||
|
evaluateDataFrame(
|
||||||
|
ANIMALS_DATA_INDEX,
|
||||||
|
new Classification(
|
||||||
|
ANIMAL_NAME_FIELD,
|
||||||
|
ANIMAL_NAME_PREDICTION_FIELD,
|
||||||
|
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
|
||||||
|
|
||||||
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||||
|
assertThat(
|
||||||
|
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
|
||||||
|
contains(
|
||||||
|
Accuracy.NAME.getPreferredName(),
|
||||||
|
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||||
|
Precision.NAME.getPreferredName(),
|
||||||
|
Recall.NAME.getPreferredName()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate_Accuracy_KeywordField() {
|
public void testEvaluate_Accuracy_KeywordField() {
|
||||||
|
@ -70,14 +91,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||||
assertThat(
|
assertThat(
|
||||||
accuracyResult.getActualClasses(),
|
accuracyResult.getClasses(),
|
||||||
equalTo(
|
equalTo(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
new Accuracy.ActualClass("ant", 15, 1.0 / 15),
|
new Accuracy.PerClassResult("ant", 47.0 / 75),
|
||||||
new Accuracy.ActualClass("cat", 15, 1.0 / 15),
|
new Accuracy.PerClassResult("cat", 47.0 / 75),
|
||||||
new Accuracy.ActualClass("dog", 15, 1.0 / 15),
|
new Accuracy.PerClassResult("dog", 47.0 / 75),
|
||||||
new Accuracy.ActualClass("fox", 15, 1.0 / 15),
|
new Accuracy.PerClassResult("fox", 47.0 / 75),
|
||||||
new Accuracy.ActualClass("mouse", 15, 1.0 / 15))));
|
new Accuracy.PerClassResult("mouse", 47.0 / 75))));
|
||||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
|
assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,13 +113,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||||
assertThat(
|
assertThat(
|
||||||
accuracyResult.getActualClasses(),
|
accuracyResult.getClasses(),
|
||||||
equalTo(Arrays.asList(
|
equalTo(
|
||||||
new Accuracy.ActualClass("1", 15, 1.0 / 15),
|
Arrays.asList(
|
||||||
new Accuracy.ActualClass("2", 15, 2.0 / 15),
|
new Accuracy.PerClassResult("1", 57.0 / 75),
|
||||||
new Accuracy.ActualClass("3", 15, 3.0 / 15),
|
new Accuracy.PerClassResult("2", 54.0 / 75),
|
||||||
new Accuracy.ActualClass("4", 15, 4.0 / 15),
|
new Accuracy.PerClassResult("3", 51.0 / 75),
|
||||||
new Accuracy.ActualClass("5", 15, 5.0 / 15))));
|
new Accuracy.PerClassResult("4", 48.0 / 75),
|
||||||
|
new Accuracy.PerClassResult("5", 45.0 / 75))));
|
||||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
|
assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,10 +135,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||||
assertThat(
|
assertThat(
|
||||||
accuracyResult.getActualClasses(),
|
accuracyResult.getClasses(),
|
||||||
equalTo(Arrays.asList(
|
equalTo(
|
||||||
new Accuracy.ActualClass("true", 45, 27.0 / 45),
|
Arrays.asList(
|
||||||
new Accuracy.ActualClass("false", 30, 18.0 / 30))));
|
new Accuracy.PerClassResult("false", 18.0 / 30),
|
||||||
|
new Accuracy.PerClassResult("true", 27.0 / 45))));
|
||||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
|
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,7 +275,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||||
evaluateDataFrame(
|
evaluateDataFrame(
|
||||||
ANIMALS_DATA_INDEX,
|
ANIMALS_DATA_INDEX,
|
||||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
|
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3, null))));
|
||||||
|
|
||||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||||
|
|
|
@ -466,12 +466,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
{ // Accuracy
|
{ // Accuracy
|
||||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||||
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
|
||||||
List<Accuracy.ActualClass> actualClasses = accuracyResult.getActualClasses();
|
for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) {
|
||||||
assertThat(
|
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
|
||||||
actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()),
|
assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
|
||||||
equalTo(dependentVariableValuesAsStrings));
|
}
|
||||||
actualClasses.forEach(
|
|
||||||
actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{ // MulticlassConfusionMatrix
|
{ // MulticlassConfusionMatrix
|
||||||
|
|
|
@ -620,16 +620,13 @@ setup:
|
||||||
|
|
||||||
- match:
|
- match:
|
||||||
classification.accuracy:
|
classification.accuracy:
|
||||||
actual_classes:
|
classes:
|
||||||
- actual_class: "cat"
|
- class_name: "cat"
|
||||||
actual_class_doc_count: 3
|
accuracy: 0.625 # 5 out of 8
|
||||||
accuracy: 0.6666666666666666 # 2 out of 3
|
- class_name: "dog"
|
||||||
- actual_class: "dog"
|
accuracy: 0.75 # 6 out of 8
|
||||||
actual_class_doc_count: 3
|
- class_name: "mouse"
|
||||||
accuracy: 0.6666666666666666 # 2 out of 3
|
accuracy: 0.875 # 7 out of 8
|
||||||
- actual_class: "mouse"
|
|
||||||
actual_class_doc_count: 2
|
|
||||||
accuracy: 0.5 # 1 out of 2
|
|
||||||
overall_accuracy: 0.625 # 5 out of 8
|
overall_accuracy: 0.625 # 5 out of 8
|
||||||
---
|
---
|
||||||
"Test classification precision":
|
"Test classification precision":
|
||||||
|
|
Loading…
Reference in New Issue