[7.x] Change format of MulticlassConfusionMatrix result to be more self-explanatory (#48174) (#48294)

This commit is contained in:
Przemysław Witek 2019-10-21 22:07:19 +02:00 committed by GitHub
parent 178204703a
commit 2db2b945ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 714 additions and 277 deletions

View File

@ -21,17 +21,17 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.List;
import java.util.Objects;
import java.util.TreeMap;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
@ -97,32 +97,28 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
public static class Result implements EvaluationMetric.Result {
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));
static {
PARSER.declareObject(
constructorArg(),
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
// Immutable
private final Map<String, Map<String, Long>> confusionMatrix;
private final long otherClassesCount;
private final List<ActualClass> confusionMatrix;
private final Long otherActualClassCount;
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
this.otherClassesCount = otherClassesCount;
public Result(@Nullable List<ActualClass> confusionMatrix, @Nullable Long otherActualClassCount) {
this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null;
this.otherActualClassCount = otherActualClassCount;
}
@Override
@ -130,19 +126,23 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
return NAME;
}
public Map<String, Map<String, Long>> getConfusionMatrix() {
public List<ActualClass> getConfusionMatrix() {
return confusionMatrix;
}
public long getOtherClassesCount() {
return otherClassesCount;
public Long getOtherActualClassCount() {
return otherActualClassCount;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (confusionMatrix != null) {
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
}
if (otherActualClassCount != null) {
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
}
builder.endObject();
return builder;
}
@ -153,12 +153,140 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
&& this.otherClassesCount == that.otherClassesCount;
&& Objects.equals(this.otherActualClassCount, that.otherActualClassCount);
}
@Override
public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount);
return Objects.hash(confusionMatrix, otherActualClassCount);
}
}
public static class ActualClass implements ToXContentObject {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_actual_class",
true,
a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));
static {
PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS);
PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
}
private final String actualClass;
private final Long actualClassDocCount;
private final List<PredictedClass> predictedClasses;
private final Long otherPredictedClassDocCount;
public ActualClass(@Nullable String actualClass,
@Nullable Long actualClassDocCount,
@Nullable List<PredictedClass> predictedClasses,
@Nullable Long otherPredictedClassDocCount) {
this.actualClass = actualClass;
this.actualClassDocCount = actualClassDocCount;
this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null;
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (actualClass != null) {
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
}
if (actualClassDocCount != null) {
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
}
if (predictedClasses != null) {
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
}
if (otherPredictedClassDocCount != null) {
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& Objects.equals(this.actualClassDocCount, that.actualClassDocCount)
&& Objects.equals(this.predictedClasses, that.predictedClasses)
&& Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount);
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
}
@Override
public String toString() {
return Strings.toString(this);
}
}
public static class PredictedClass implements ToXContentObject {
private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
private static final ParseField COUNT = new ParseField("count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));
static {
PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS);
PARSER.declareLong(optionalConstructorArg(), COUNT);
}
private final String predictedClass;
private final Long count;
public PredictedClass(@Nullable String predictedClass, @Nullable Long count) {
this.predictedClass = predictedClass;
this.count = count;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (predictedClass != null) {
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
}
if (count != null) {
builder.field(COUNT.getPreferredName(), count);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PredictedClass that = (PredictedClass) o;
return Objects.equals(this.predictedClass, that.predictedClass)
&& Objects.equals(this.count, that.count);
}
@Override
public int hashCode() {
return Objects.hash(predictedClass, count);
}
}
}

View File

@ -127,6 +127,8 @@ import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
@ -1807,7 +1809,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "horse", "cat"));
.add(docForClassification(indexName, "ant", "cat"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@ -1827,22 +1829,26 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 3L);
expectedConfusionMatrix.get("cat").put("dog", 1L);
expectedConfusionMatrix.get("cat").put("horse", 0L);
expectedConfusionMatrix.get("cat").put("_other_", 1L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 1L);
expectedConfusionMatrix.get("dog").put("dog", 3L);
expectedConfusionMatrix.get("dog").put("horse", 0L);
expectedConfusionMatrix.put("horse", new HashMap<>());
expectedConfusionMatrix.get("horse").put("cat", 1L);
expectedConfusionMatrix.get("horse").put("dog", 0L);
expectedConfusionMatrix.get("horse").put("horse", 0L);
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
assertThat(
mcmResult.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass(
"ant",
1L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
0L),
new ActualClass(
"cat",
5L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1L),
new ActualClass(
"dog",
4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
}
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
EvaluateDataFrameRequest evaluateDataFrameRequest =
@ -1859,16 +1865,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 3L);
expectedConfusionMatrix.get("cat").put("dog", 1L);
expectedConfusionMatrix.get("cat").put("_other_", 1L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 1L);
expectedConfusionMatrix.get("dog").put("dog", 3L);
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
assertThat(
mcmResult.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
)));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
}
}

View File

@ -142,6 +142,8 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -3355,33 +3357,31 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
// end::evaluate-data-frame-results-classification
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
assertThat(
confusionMatrix,
equalTo(
new HashMap<String, Map<String, Long>>() {{
put("cat", new HashMap<String, Long>() {{
put("cat", 3L);
put("dog", 1L);
put("ant", 0L);
put("_other_", 1L);
}});
put("dog", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 3L);
put("ant", 0L);
}});
put("ant", new HashMap<String, Long>() {{
put("cat", 1L);
put("dog", 0L);
put("ant", 0L);
}});
}}));
assertThat(otherClassesCount, equalTo(0L));
Arrays.asList(
new ActualClass(
"ant",
1L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
0L),
new ActualClass(
"cat",
5L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1L),
new ActualClass(
"dog",
4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(otherActualClassCount, equalTo(0L));
}
}

View File

@ -19,19 +19,21 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric.Result> {
public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<Result> {
@Override
protected NamedXContentRegistry xContentRegistry() {
@ -39,26 +41,28 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
}
@Override
protected MulticlassConfusionMatrixMetric.Result createTestInstance() {
protected Result createTestInstance() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
Map<String, Long> row = new TreeMap<>();
confusionMatrix.put(classNames.get(i), row);
List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
for (int j = 0; j < numClasses; j++) {
if (randomBoolean()) {
row.put(classNames.get(i), randomNonNegativeLong());
predictedClasses.add(new PredictedClass(classNames.get(j), randomBoolean() ? randomNonNegativeLong() : null));
}
actualClasses.add(
new ActualClass(
classNames.get(i),
randomBoolean() ? randomNonNegativeLong() : null,
predictedClasses,
randomBoolean() ? randomNonNegativeLong() : null));
}
}
long otherClassesCount = randomNonNegativeLong();
return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount);
return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
}
@Override
protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrixMetric.Result.fromXContent(parser);
protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser);
}
@Override

View File

@ -9,7 +9,9 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
@ -26,14 +28,14 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.stream.Collectors;
import static java.util.Comparator.comparing;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -112,18 +114,19 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
.size(size));
}
if (result == null) { // This is step 2
KeyedFilter[] keyedFilters =
KeyedFilter[] keyedFiltersActual =
topActualClassNames.stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new);
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
return Arrays.asList(
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
.field(actualField),
AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters)
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY)));
}
@ -134,26 +137,31 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList());
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
}
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
Map<String, Map<String, Long>> counts = new TreeMap<>();
for (Terms.Bucket bucket : termsAgg.getBuckets()) {
Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString();
Map<String, Long> subCounts = new TreeMap<>();
counts.put(actualClass, subCounts);
long actualClassDocCount = bucket.getDocCount();
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
List<PredictedClass> predictedClasses = new ArrayList<>();
long otherPredictedClassDocCount = 0;
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
String predictedClass = subBucket.getKeyAsString();
Long docCount = subBucket.getDocCount();
if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) {
subCounts.put(predictedClass, docCount);
long docCount = subBucket.getDocCount();
if (OTHER_BUCKET_KEY.equals(predictedClass)) {
otherPredictedClassDocCount = docCount;
} else {
predictedClasses.add(new PredictedClass(predictedClass, docCount));
}
}
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
}
result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size);
result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
}
}
@ -191,37 +199,35 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
public static class Result implements EvaluationMetricResult {
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (long) a[1]));
static {
PARSER.declareObject(
constructorArg(),
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
// Immutable
private final Map<String, Map<String, Long>> confusionMatrix;
private final long otherClassesCount;
/** List of actual classes. */
private final List<ActualClass> actualClasses;
/** Number of actual classes that were not included in the confusion matrix because there were too many of them. */
private final long otherActualClassCount;
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
this.otherClassesCount = otherClassesCount;
public Result(List<ActualClass> actualClasses, long otherActualClassCount) {
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, CONFUSION_MATRIX));
this.otherActualClassCount = requireNonNegative(otherActualClassCount, OTHER_ACTUAL_CLASS_COUNT);
}
public Result(StreamInput in) throws IOException {
this.confusionMatrix = Collections.unmodifiableMap(
in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong)));
this.otherClassesCount = in.readLong();
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
this.otherActualClassCount = in.readVLong();
}
@Override
@ -234,28 +240,25 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
return NAME.getPreferredName();
}
public Map<String, Map<String, Long>> getConfusionMatrix() {
return confusionMatrix;
public List<ActualClass> getConfusionMatrix() {
return actualClasses;
}
public long getOtherClassesCount() {
return otherClassesCount;
public long getOtherActualClassCount() {
return otherActualClassCount;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeMap(
confusionMatrix,
StreamOutput::writeString,
(out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong));
out.writeLong(otherClassesCount);
out.writeList(actualClasses);
out.writeVLong(otherActualClassCount);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
builder.field(CONFUSION_MATRIX.getPreferredName(), actualClasses);
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
builder.endObject();
return builder;
}
@ -265,13 +268,163 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
&& this.otherClassesCount == that.otherClassesCount;
return Objects.equals(this.actualClasses, that.actualClasses)
&& this.otherActualClassCount == that.otherActualClassCount;
}
@Override
public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount);
return Objects.hash(actualClasses, otherActualClassCount);
}
}
public static class ActualClass implements ToXContentObject, Writeable {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_actual_class",
true,
a -> new ActualClass((String) a[0], (long) a[1], (List<PredictedClass>) a[2], (long) a[3]));
static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
}
/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {code actualClass} class. */
private final long actualClassDocCount;
/** List of predicted classes. */
private final List<PredictedClass> predictedClasses;
/** Number of documents that were not predicted as any of the {@code predictedClasses}. */
private final long otherPredictedClassDocCount;
public ActualClass(
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
this.actualClassDocCount = requireNonNegative(actualClassDocCount, ACTUAL_CLASS_DOC_COUNT);
this.predictedClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(predictedClasses, PREDICTED_CLASSES));
this.otherPredictedClassDocCount = requireNonNegative(otherPredictedClassDocCount, OTHER_PREDICTED_CLASS_DOC_COUNT);
}
public ActualClass(StreamInput in) throws IOException {
this.actualClass = in.readString();
this.actualClassDocCount = in.readVLong();
this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new));
this.otherPredictedClassDocCount = in.readVLong();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualClass);
out.writeVLong(actualClassDocCount);
out.writeList(predictedClasses);
out.writeVLong(otherPredictedClassDocCount);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
&& Objects.equals(this.predictedClasses, that.predictedClasses)
&& this.otherPredictedClassDocCount == that.otherPredictedClassDocCount;
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
}
}
public static class PredictedClass implements ToXContentObject, Writeable {
private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
private static final ParseField COUNT = new ParseField("count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1]));
static {
PARSER.declareString(constructorArg(), PREDICTED_CLASS);
PARSER.declareLong(constructorArg(), COUNT);
}
private final String predictedClass;
private final long count;
public PredictedClass(String predictedClass, long count) {
this.predictedClass = ExceptionsHelper.requireNonNull(predictedClass, PREDICTED_CLASS);
this.count = requireNonNegative(count, COUNT);
}
public PredictedClass(StreamInput in) throws IOException {
this.predictedClass = in.readString();
this.count = in.readVLong();
}
public String getPredictedClass() {
return predictedClass;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(predictedClass);
out.writeVLong(count);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
builder.field(COUNT.getPreferredName(), count);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PredictedClass that = (PredictedClass) o;
return Objects.equals(this.predictedClass, that.predictedClass)
&& this.count == that.count;
}
@Override
public int hashCode() {
return Objects.hash(predictedClass, count);
}
}
private static long requireNonNegative(long value, ParseField field) {
if (value < 0) {
throw ExceptionsHelper.serverError("[" + field.getPreferredName() + "] must be >= 0, was: " + value);
}
return value;
}
}

View File

@ -5,50 +5,53 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix.Result> {
import static org.hamcrest.Matchers.equalTo;
public static MulticlassConfusionMatrix.Result createRandom() {
public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<Result> {
public static Result createRandom() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
Map<String, Long> row = new TreeMap<>();
confusionMatrix.put(classNames.get(i), row);
List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
for (int j = 0; j < numClasses; j++) {
if (randomBoolean()) {
row.put(classNames.get(i), randomNonNegativeLong());
predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
}
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong()));
}
}
long otherClassesCount = randomNonNegativeLong();
return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount);
return new Result(actualClasses, randomNonNegativeLong());
}
@Override
protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrix.Result.fromXContent(parser);
protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser);
}
@Override
protected MulticlassConfusionMatrix.Result createTestInstance() {
protected Result createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<MulticlassConfusionMatrix.Result> instanceReader() {
return MulticlassConfusionMatrix.Result::new;
protected Writeable.Reader<Result> instanceReader() {
return Result::new;
}
@Override
@ -61,4 +64,67 @@ public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTes
// allow unknown fields in the root of the object only
return field -> !field.isEmpty();
}
public void testConstructor_ValidationFailures() {
{
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new Result(null, 0));
assertThat(e.getMessage(), equalTo("[confusion_matrix] must not be null."));
}
{
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new Result(Collections.emptyList(), -1));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[other_actual_class_count] must be >= 0, was: -1"));
}
{
IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class,
() -> new Result(Collections.singletonList(new ActualClass(null, 0, Collections.emptyList(), 0)), 0));
assertThat(e.getMessage(), equalTo("[actual_class] must not be null."));
}
{
ElasticsearchException e =
expectThrows(
ElasticsearchException.class,
() -> new Result(Collections.singletonList(new ActualClass("actual_class", -1, Collections.emptyList(), 0)), 0));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[actual_class_doc_count] must be >= 0, was: -1"));
}
{
IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class,
() -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, null, 0)), 0));
assertThat(e.getMessage(), equalTo("[predicted_classes] must not be null."));
}
{
ElasticsearchException e =
expectThrows(
ElasticsearchException.class,
() -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, Collections.emptyList(), -1)), 0));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[other_predicted_class_doc_count] must be >= 0, was: -1"));
}
{
IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class,
() -> new Result(
Collections.singletonList(
new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass(null, 0)), 0)),
0));
assertThat(e.getMessage(), equalTo("[predicted_class] must not be null."));
}
{
ElasticsearchException e =
expectThrows(
ElasticsearchException.class,
() -> new Result(
Collections.singletonList(
new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass("predicted_class", -1)), 0)),
0));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[count] must be >= 0, was: -1"));
}
}
}

View File

@ -14,13 +14,13 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static org.hamcrest.Matchers.empty;
@ -88,22 +88,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
0L),
mockTerms(
mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class",
Arrays.asList(
mockTermsBucket(
mockFiltersBucket(
"dog",
30,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockTermsBucket(
mockFiltersBucket(
"cat",
70,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))),
0L),
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
@ -112,15 +113,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 10L);
expectedConfusionMatrix.get("dog").put("dog", 20L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 30L);
expectedConfusionMatrix.get("cat").put("dog", 40L);
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(result.getOtherClassesCount(), equalTo(0L));
assertThat(
result.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0),
new ActualClass("cat", 70, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0))));
assertThat(result.getOtherActualClassCount(), equalTo(0L));
}
public void testEvaluate_OtherClassesCountGreaterThanZero() {
@ -131,22 +130,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L),
mockTerms(
mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class",
Arrays.asList(
mockTermsBucket(
mockFiltersBucket(
"dog",
30,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockTermsBucket(
mockFiltersBucket(
"cat",
85,
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))),
100L),
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
@ -155,16 +155,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 10L);
expectedConfusionMatrix.get("dog").put("dog", 20L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 30L);
expectedConfusionMatrix.get("cat").put("dog", 40L);
expectedConfusionMatrix.get("cat").put("_other_", 15L);
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(result.getOtherClassesCount(), equalTo(3L));
assertThat(
result.getConfusionMatrix(),
equalTo(
Arrays.asList(
new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0),
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
assertThat(result.getOtherActualClassCount(), equalTo(3L));
}
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
@ -175,9 +172,9 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
return aggregation;
}
private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) {
private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
Terms.Bucket bucket = mock(Terms.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(actualClass);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
@ -189,9 +186,15 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
return aggregation;
}
private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) {
private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
private static Filters.Bucket mockFiltersBucket(String key, long docCount) {
Filters.Bucket bucket = mock(Filters.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(predictedClass);
when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getDocCount()).thenReturn(docCount);
return bucket;
}

View File

@ -12,13 +12,13 @@ import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.junit.After;
import org.junit.Before;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
@ -53,39 +53,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("ant", new HashMap<>());
expectedConfusionMatrix.get("ant").put("ant", 1L);
expectedConfusionMatrix.get("ant").put("cat", 4L);
expectedConfusionMatrix.get("ant").put("dog", 3L);
expectedConfusionMatrix.get("ant").put("fox", 2L);
expectedConfusionMatrix.get("ant").put("mouse", 5L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("ant", 3L);
expectedConfusionMatrix.get("cat").put("cat", 1L);
expectedConfusionMatrix.get("cat").put("dog", 5L);
expectedConfusionMatrix.get("cat").put("fox", 4L);
expectedConfusionMatrix.get("cat").put("mouse", 2L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("ant", 4L);
expectedConfusionMatrix.get("dog").put("cat", 2L);
expectedConfusionMatrix.get("dog").put("dog", 1L);
expectedConfusionMatrix.get("dog").put("fox", 5L);
expectedConfusionMatrix.get("dog").put("mouse", 3L);
expectedConfusionMatrix.put("fox", new HashMap<>());
expectedConfusionMatrix.get("fox").put("ant", 5L);
expectedConfusionMatrix.get("fox").put("cat", 3L);
expectedConfusionMatrix.get("fox").put("dog", 2L);
expectedConfusionMatrix.get("fox").put("fox", 1L);
expectedConfusionMatrix.get("fox").put("mouse", 4L);
expectedConfusionMatrix.put("mouse", new HashMap<>());
expectedConfusionMatrix.get("mouse").put("ant", 2L);
expectedConfusionMatrix.get("mouse").put("cat", 5L);
expectedConfusionMatrix.get("mouse").put("dog", 4L);
expectedConfusionMatrix.get("mouse").put("fox", 3L);
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
assertThat(
confusionMatrixResult.getConfusionMatrix(),
equalTo(Arrays.asList(
new ActualClass("ant",
15,
Arrays.asList(
new PredictedClass("ant", 1L),
new PredictedClass("cat", 4L),
new PredictedClass("dog", 3L),
new PredictedClass("fox", 2L),
new PredictedClass("mouse", 5L)),
0),
new ActualClass("cat",
15,
Arrays.asList(
new PredictedClass("ant", 3L),
new PredictedClass("cat", 1L),
new PredictedClass("dog", 5L),
new PredictedClass("fox", 4L),
new PredictedClass("mouse", 2L)),
0),
new ActualClass("dog",
15,
Arrays.asList(
new PredictedClass("ant", 4L),
new PredictedClass("cat", 2L),
new PredictedClass("dog", 1L),
new PredictedClass("fox", 5L),
new PredictedClass("mouse", 3L)),
0),
new ActualClass("fox",
15,
Arrays.asList(
new PredictedClass("ant", 5L),
new PredictedClass("cat", 3L),
new PredictedClass("dog", 2L),
new PredictedClass("fox", 1L),
new PredictedClass("mouse", 4L)),
0),
new ActualClass("mouse",
15,
Arrays.asList(
new PredictedClass("ant", 2L),
new PredictedClass("cat", 5L),
new PredictedClass("dog", 4L),
new PredictedClass("fox", 3L),
new PredictedClass("mouse", 1L)),
0))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
@ -103,39 +119,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("ant", new HashMap<>());
expectedConfusionMatrix.get("ant").put("ant", 1L);
expectedConfusionMatrix.get("ant").put("cat", 4L);
expectedConfusionMatrix.get("ant").put("dog", 3L);
expectedConfusionMatrix.get("ant").put("fox", 2L);
expectedConfusionMatrix.get("ant").put("mouse", 5L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("ant", 3L);
expectedConfusionMatrix.get("cat").put("cat", 1L);
expectedConfusionMatrix.get("cat").put("dog", 5L);
expectedConfusionMatrix.get("cat").put("fox", 4L);
expectedConfusionMatrix.get("cat").put("mouse", 2L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("ant", 4L);
expectedConfusionMatrix.get("dog").put("cat", 2L);
expectedConfusionMatrix.get("dog").put("dog", 1L);
expectedConfusionMatrix.get("dog").put("fox", 5L);
expectedConfusionMatrix.get("dog").put("mouse", 3L);
expectedConfusionMatrix.put("fox", new HashMap<>());
expectedConfusionMatrix.get("fox").put("ant", 5L);
expectedConfusionMatrix.get("fox").put("cat", 3L);
expectedConfusionMatrix.get("fox").put("dog", 2L);
expectedConfusionMatrix.get("fox").put("fox", 1L);
expectedConfusionMatrix.get("fox").put("mouse", 4L);
expectedConfusionMatrix.put("mouse", new HashMap<>());
expectedConfusionMatrix.get("mouse").put("ant", 2L);
expectedConfusionMatrix.get("mouse").put("cat", 5L);
expectedConfusionMatrix.get("mouse").put("dog", 4L);
expectedConfusionMatrix.get("mouse").put("fox", 3L);
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
assertThat(
confusionMatrixResult.getConfusionMatrix(),
equalTo(Arrays.asList(
new ActualClass("ant",
15,
Arrays.asList(
new PredictedClass("ant", 1L),
new PredictedClass("cat", 4L),
new PredictedClass("dog", 3L),
new PredictedClass("fox", 2L),
new PredictedClass("mouse", 5L)),
0),
new ActualClass("cat",
15,
Arrays.asList(
new PredictedClass("ant", 3L),
new PredictedClass("cat", 1L),
new PredictedClass("dog", 5L),
new PredictedClass("fox", 4L),
new PredictedClass("mouse", 2L)),
0),
new ActualClass("dog",
15,
Arrays.asList(
new PredictedClass("ant", 4L),
new PredictedClass("cat", 2L),
new PredictedClass("dog", 1L),
new PredictedClass("fox", 5L),
new PredictedClass("mouse", 3L)),
0),
new ActualClass("fox",
15,
Arrays.asList(
new PredictedClass("ant", 5L),
new PredictedClass("cat", 3L),
new PredictedClass("dog", 2L),
new PredictedClass("fox", 1L),
new PredictedClass("mouse", 4L)),
0),
new ActualClass("mouse",
15,
Arrays.asList(
new PredictedClass("ant", 2L),
new PredictedClass("cat", 5L),
new PredictedClass("dog", 4L),
new PredictedClass("fox", 3L),
new PredictedClass("mouse", 1L)),
0))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
@ -153,24 +185,22 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("ant", new HashMap<>());
expectedConfusionMatrix.get("ant").put("ant", 1L);
expectedConfusionMatrix.get("ant").put("cat", 4L);
expectedConfusionMatrix.get("ant").put("dog", 3L);
expectedConfusionMatrix.get("ant").put("_other_", 7L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("ant", 3L);
expectedConfusionMatrix.get("cat").put("cat", 1L);
expectedConfusionMatrix.get("cat").put("dog", 5L);
expectedConfusionMatrix.get("cat").put("_other_", 6L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("ant", 4L);
expectedConfusionMatrix.get("dog").put("cat", 2L);
expectedConfusionMatrix.get("dog").put("dog", 1L);
expectedConfusionMatrix.get("dog").put("_other_", 8L);
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L));
assertThat(
confusionMatrixResult.getConfusionMatrix(),
equalTo(Arrays.asList(
new ActualClass("ant",
15,
Arrays.asList(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)),
7),
new ActualClass("cat",
15,
Arrays.asList(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)),
6),
new ActualClass("dog",
15,
Arrays.asList(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)),
8))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
}
private static void indexAnimalsData(String indexName) {

View File

@ -618,8 +618,40 @@ setup:
}
}
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
- match: { classification.multiclass_confusion_matrix._other_: 0 }
- match:
classification.multiclass_confusion_matrix:
confusion_matrix:
- actual_class: "cat"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 2
- predicted_class: "dog"
count: 1
- predicted_class: "mouse"
count: 0
other_predicted_class_doc_count: 0
- actual_class: "dog"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 1
- predicted_class: "dog"
count: 2
- predicted_class: "mouse"
count: 0
other_predicted_class_doc_count: 0
- actual_class: "mouse"
actual_class_doc_count: 2
predicted_classes:
- predicted_class: "cat"
count: 1
- predicted_class: "dog"
count: 0
- predicted_class: "mouse"
count: 1
other_predicted_class_doc_count: 0
other_actual_class_count: 0
---
"Test classification multiclass_confusion_matrix with explicit size":
- do:
@ -636,8 +668,26 @@ setup:
}
}
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } }
- match: { classification.multiclass_confusion_matrix._other_: 1 }
- match:
classification.multiclass_confusion_matrix:
confusion_matrix:
- actual_class: "cat"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 2
- predicted_class: "dog"
count: 1
other_predicted_class_doc_count: 0
- actual_class: "dog"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 1
- predicted_class: "dog"
count: 2
other_predicted_class_doc_count: 0
other_actual_class_count: 1
---
"Test classification with null metrics":
- do:
@ -653,8 +703,7 @@ setup:
}
}
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
- match: { classification.multiclass_confusion_matrix._other_: 0 }
- is_true: classification.multiclass_confusion_matrix
---
"Test classification given missing actual_field":
- do: