diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java index a8e8545009b..7199660e94d 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -21,17 +21,17 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Collections; -import java.util.Map; +import java.util.List; import java.util.Objects; -import java.util.TreeMap; -import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; /** @@ -97,32 +97,28 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric { public static class Result implements EvaluationMetric.Result { private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "multiclass_confusion_matrix_result", true, a -> new Result((Map>) a[0], (long) a[1])); + "multiclass_confusion_matrix_result", true, a -> new Result((List) a[0], (Long) a[1])); static { - PARSER.declareObject( - constructorArg(), - (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)), - CONFUSION_MATRIX); - PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); + PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT); } public static Result fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - // Immutable - private final Map> confusionMatrix; - private final long otherClassesCount; + private final List confusionMatrix; + private final Long otherActualClassCount; - public Result(Map> confusionMatrix, long otherClassesCount) { - this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); - this.otherClassesCount = otherClassesCount; + public Result(@Nullable List confusionMatrix, @Nullable Long otherActualClassCount) { + this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null; + this.otherActualClassCount = otherActualClassCount; } @Override @@ -130,19 +126,23 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric { return NAME; } - public Map> getConfusionMatrix() { + public List getConfusionMatrix() { return confusionMatrix; } - public long getOtherClassesCount() { - return otherClassesCount; + public Long getOtherActualClassCount() { + return otherActualClassCount; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); - builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + if (confusionMatrix != null) { + builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); + } + if (otherActualClassCount != null) { + builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount); + } builder.endObject(); return builder; } @@ -153,12 +153,140 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric { if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; return Objects.equals(this.confusionMatrix, that.confusionMatrix) - && this.otherClassesCount == that.otherClassesCount; + && Objects.equals(this.otherActualClassCount, that.otherActualClassCount); } @Override public int hashCode() { - return Objects.hash(confusionMatrix, otherClassesCount); + return Objects.hash(confusionMatrix, otherActualClassCount); + } + } + + public static class ActualClass implements ToXContentObject { + + private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); + private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_actual_class", + true, + a -> new ActualClass((String) a[0], (Long) a[1], (List) a[2], (Long) a[3])); + + static { + PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS); + PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); + PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT); + } + + private final String actualClass; + private final Long actualClassDocCount; + private final List predictedClasses; + private final Long otherPredictedClassDocCount; + + public ActualClass(@Nullable String actualClass, + @Nullable Long actualClassDocCount, + @Nullable List predictedClasses, + @Nullable Long otherPredictedClassDocCount) { + this.actualClass = actualClass; + this.actualClassDocCount = actualClassDocCount; + this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null; + this.otherPredictedClassDocCount = otherPredictedClassDocCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (actualClass != null) { + builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + } + if (actualClassDocCount != null) { + builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + } + if (predictedClasses != null) { + builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); + } + if (otherPredictedClassDocCount != null) { + builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ActualClass that = (ActualClass) o; + return Objects.equals(this.actualClass, that.actualClass) + && Objects.equals(this.actualClassDocCount, that.actualClassDocCount) + && Objects.equals(this.predictedClasses, that.predictedClasses) + && Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount); + } + + @Override + public int hashCode() { + return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static class PredictedClass implements ToXContentObject { + + private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class"); + private static final ParseField COUNT = new ParseField("count"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1])); + + static { + PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS); + PARSER.declareLong(optionalConstructorArg(), COUNT); + } + + private final String predictedClass; + private final Long count; + + public PredictedClass(@Nullable String predictedClass, @Nullable Long count) { + this.predictedClass = predictedClass; + this.count = count; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (predictedClass != null) { + builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass); + } + if (count != null) { + builder.field(COUNT.getPreferredName(), count); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PredictedClass that = (PredictedClass) o; + return Objects.equals(this.predictedClass, that.predictedClass) + && Objects.equals(this.count, that.count); + } + + @Override + public int hashCode() { + return Objects.hash(predictedClass, count); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 07cfa15b34c..bc42aba9ecd 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -127,6 +127,8 @@ import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; @@ -1807,7 +1809,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog")) - .add(docForClassification(indexName, "horse", "cat")); + .add(docForClassification(indexName, "ant", "cat")); highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); @@ -1827,22 +1829,26 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { MulticlassConfusionMatrixMetric.Result mcmResult = evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME); assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); - Map> expectedConfusionMatrix = new HashMap<>(); - expectedConfusionMatrix.put("cat", new HashMap<>()); - expectedConfusionMatrix.get("cat").put("cat", 3L); - expectedConfusionMatrix.get("cat").put("dog", 1L); - expectedConfusionMatrix.get("cat").put("horse", 0L); - expectedConfusionMatrix.get("cat").put("_other_", 1L); - expectedConfusionMatrix.put("dog", new HashMap<>()); - expectedConfusionMatrix.get("dog").put("cat", 1L); - expectedConfusionMatrix.get("dog").put("dog", 3L); - expectedConfusionMatrix.get("dog").put("horse", 0L); - expectedConfusionMatrix.put("horse", new HashMap<>()); - expectedConfusionMatrix.get("horse").put("cat", 1L); - expectedConfusionMatrix.get("horse").put("dog", 0L); - expectedConfusionMatrix.get("horse").put("horse", 0L); - assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); - assertThat(mcmResult.getOtherClassesCount(), equalTo(0L)); + assertThat( + mcmResult.getConfusionMatrix(), + equalTo( + Arrays.asList( + new ActualClass( + "ant", + 1L, + Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), + 0L), + new ActualClass( + "cat", + 5L, + Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), + 1L), + new ActualClass( + "dog", + 4L, + Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), + 0L)))); + assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L)); } { // Explicit size provided for MulticlassConfusionMatrixMetric metric EvaluateDataFrameRequest evaluateDataFrameRequest = @@ -1859,16 +1865,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { MulticlassConfusionMatrixMetric.Result mcmResult = evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME); assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); - Map> expectedConfusionMatrix = new HashMap<>(); - expectedConfusionMatrix.put("cat", new HashMap<>()); - expectedConfusionMatrix.get("cat").put("cat", 3L); - expectedConfusionMatrix.get("cat").put("dog", 1L); - expectedConfusionMatrix.get("cat").put("_other_", 1L); - expectedConfusionMatrix.put("dog", new HashMap<>()); - expectedConfusionMatrix.get("dog").put("cat", 1L); - expectedConfusionMatrix.get("dog").put("dog", 3L); - assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); - assertThat(mcmResult.getOtherClassesCount(), equalTo(1L)); + assertThat( + mcmResult.getConfusionMatrix(), + equalTo( + Arrays.asList( + new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L), + new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L) + ))); + assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 9bfe943e2c0..762eaaaf906 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -142,6 +142,8 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -3355,33 +3357,31 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1> - Map> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2> - long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3> + List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2> + long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3> // end::evaluate-data-frame-results-classification assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); assertThat( confusionMatrix, equalTo( - new HashMap>() {{ - put("cat", new HashMap() {{ - put("cat", 3L); - put("dog", 1L); - put("ant", 0L); - put("_other_", 1L); - }}); - put("dog", new HashMap() {{ - put("cat", 1L); - put("dog", 3L); - put("ant", 0L); - }}); - put("ant", new HashMap() {{ - put("cat", 1L); - put("dog", 0L); - put("ant", 0L); - }}); - }})); - assertThat(otherClassesCount, equalTo(0L)); + Arrays.asList( + new ActualClass( + "ant", + 1L, + Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)), + 0L), + new ActualClass( + "cat", + 5L, + Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), + 1L), + new ActualClass( + "dog", + 4L, + Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), + 0L)))); + assertThat(otherActualClassCount, equalTo(0L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java index 800a2cf7b98..55b74eb94ea 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -19,19 +19,21 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.Result; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; +import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.TreeMap; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase { +public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase { @Override protected NamedXContentRegistry xContentRegistry() { @@ -39,26 +41,28 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent } @Override - protected MulticlassConfusionMatrixMetric.Result createTestInstance() { + protected Result createTestInstance() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - Map> confusionMatrix = new TreeMap<>(); + List actualClasses = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { - Map row = new TreeMap<>(); - confusionMatrix.put(classNames.get(i), row); + List predictedClasses = new ArrayList<>(numClasses); for (int j = 0; j < numClasses; j++) { - if (randomBoolean()) { - row.put(classNames.get(i), randomNonNegativeLong()); - } + predictedClasses.add(new PredictedClass(classNames.get(j), randomBoolean() ? randomNonNegativeLong() : null)); } + actualClasses.add( + new ActualClass( + classNames.get(i), + randomBoolean() ? randomNonNegativeLong() : null, + predictedClasses, + randomBoolean() ? randomNonNegativeLong() : null)); } - long otherClassesCount = randomNonNegativeLong(); - return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount); + return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null); } @Override - protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException { - return MulticlassConfusionMatrixMetric.Result.fromXContent(parser); + protected Result doParseInstance(XContentParser parser) throws IOException { + return Result.fromXContent(parser); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java index a8b24a34447..a52afbfc871 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -9,7 +9,9 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilders; @@ -26,14 +28,14 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Arrays; +import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.TreeMap; import java.util.stream.Collectors; +import static java.util.Comparator.comparing; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -112,18 +114,19 @@ public class MulticlassConfusionMatrix implements ClassificationMetric { .size(size)); } if (result == null) { // This is step 2 - KeyedFilter[] keyedFilters = + KeyedFilter[] keyedFiltersActual = + topActualClassNames.stream() + .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className))) + .toArray(KeyedFilter[]::new); + KeyedFilter[] keyedFiltersPredicted = topActualClassNames.stream() .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .toArray(KeyedFilter[]::new); return Arrays.asList( AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) .field(actualField), - AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) - .field(actualField) - .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) - .size(size) - .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters) + AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual) + .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted) .otherBucket(true) .otherBucketKey(OTHER_BUCKET_KEY))); } @@ -134,26 +137,31 @@ public class MulticlassConfusionMatrix implements ClassificationMetric { public void process(Aggregations aggs) { if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); - topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList()); + topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()); } if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); - Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); - Map> counts = new TreeMap<>(); - for (Terms.Bucket bucket : termsAgg.getBuckets()) { + Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); + List actualClasses = new ArrayList<>(filtersAgg.getBuckets().size()); + for (Filters.Bucket bucket : filtersAgg.getBuckets()) { String actualClass = bucket.getKeyAsString(); - Map subCounts = new TreeMap<>(); - counts.put(actualClass, subCounts); + long actualClassDocCount = bucket.getDocCount(); Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); + List predictedClasses = new ArrayList<>(); + long otherPredictedClassDocCount = 0; for (Filters.Bucket subBucket : subAgg.getBuckets()) { String predictedClass = subBucket.getKeyAsString(); - Long docCount = subBucket.getDocCount(); - if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) { - subCounts.put(predictedClass, docCount); + long docCount = subBucket.getDocCount(); + if (OTHER_BUCKET_KEY.equals(predictedClass)) { + otherPredictedClassDocCount = docCount; + } else { + predictedClasses.add(new PredictedClass(predictedClass, docCount)); } } + predictedClasses.sort(comparing(PredictedClass::getPredictedClass)); + actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount)); } - result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size); + result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)); } } @@ -191,37 +199,35 @@ public class MulticlassConfusionMatrix implements ClassificationMetric { public static class Result implements EvaluationMetricResult { private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); - private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count"); + @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - "multiclass_confusion_matrix_result", true, a -> new Result((Map>) a[0], (long) a[1])); + "multiclass_confusion_matrix_result", true, a -> new Result((List) a[0], (long) a[1])); static { - PARSER.declareObject( - constructorArg(), - (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)), - CONFUSION_MATRIX); - PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX); + PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT); } public static Result fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - // Immutable - private final Map> confusionMatrix; - private final long otherClassesCount; + /** List of actual classes. */ + private final List actualClasses; + /** Number of actual classes that were not included in the confusion matrix because there were too many of them. */ + private final long otherActualClassCount; - public Result(Map> confusionMatrix, long otherClassesCount) { - this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); - this.otherClassesCount = otherClassesCount; + public Result(List actualClasses, long otherActualClassCount) { + this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, CONFUSION_MATRIX)); + this.otherActualClassCount = requireNonNegative(otherActualClassCount, OTHER_ACTUAL_CLASS_COUNT); } public Result(StreamInput in) throws IOException { - this.confusionMatrix = Collections.unmodifiableMap( - in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong))); - this.otherClassesCount = in.readLong(); + this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new)); + this.otherActualClassCount = in.readVLong(); } @Override @@ -234,28 +240,25 @@ public class MulticlassConfusionMatrix implements ClassificationMetric { return NAME.getPreferredName(); } - public Map> getConfusionMatrix() { - return confusionMatrix; + public List getConfusionMatrix() { + return actualClasses; } - public long getOtherClassesCount() { - return otherClassesCount; + public long getOtherActualClassCount() { + return otherActualClassCount; } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeMap( - confusionMatrix, - StreamOutput::writeString, - (out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong)); - out.writeLong(otherClassesCount); + out.writeList(actualClasses); + out.writeVLong(otherActualClassCount); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); - builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.field(CONFUSION_MATRIX.getPreferredName(), actualClasses); + builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount); builder.endObject(); return builder; } @@ -265,13 +268,163 @@ public class MulticlassConfusionMatrix implements ClassificationMetric { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Result that = (Result) o; - return Objects.equals(this.confusionMatrix, that.confusionMatrix) - && this.otherClassesCount == that.otherClassesCount; + return Objects.equals(this.actualClasses, that.actualClasses) + && this.otherActualClassCount == that.otherActualClassCount; } @Override public int hashCode() { - return Objects.hash(confusionMatrix, otherClassesCount); + return Objects.hash(actualClasses, otherActualClassCount); } } + + public static class ActualClass implements ToXContentObject, Writeable { + + private static final ParseField ACTUAL_CLASS = new ParseField("actual_class"); + private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count"); + private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes"); + private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_actual_class", + true, + a -> new ActualClass((String) a[0], (long) a[1], (List) a[2], (long) a[3])); + + static { + PARSER.declareString(constructorArg(), ACTUAL_CLASS); + PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT); + PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES); + PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT); + } + + /** Name of the actual class. */ + private final String actualClass; + /** Number of documents (examples) belonging to the {code actualClass} class. */ + private final long actualClassDocCount; + /** List of predicted classes. */ + private final List predictedClasses; + /** Number of documents that were not predicted as any of the {@code predictedClasses}. */ + private final long otherPredictedClassDocCount; + + public ActualClass( + String actualClass, long actualClassDocCount, List predictedClasses, long otherPredictedClassDocCount) { + this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS); + this.actualClassDocCount = requireNonNegative(actualClassDocCount, ACTUAL_CLASS_DOC_COUNT); + this.predictedClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(predictedClasses, PREDICTED_CLASSES)); + this.otherPredictedClassDocCount = requireNonNegative(otherPredictedClassDocCount, OTHER_PREDICTED_CLASS_DOC_COUNT); + } + + public ActualClass(StreamInput in) throws IOException { + this.actualClass = in.readString(); + this.actualClassDocCount = in.readVLong(); + this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new)); + this.otherPredictedClassDocCount = in.readVLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualClass); + out.writeVLong(actualClassDocCount); + out.writeList(predictedClasses); + out.writeVLong(otherPredictedClassDocCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_CLASS.getPreferredName(), actualClass); + builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount); + builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses); + builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ActualClass that = (ActualClass) o; + return Objects.equals(this.actualClass, that.actualClass) + && this.actualClassDocCount == that.actualClassDocCount + && Objects.equals(this.predictedClasses, that.predictedClasses) + && this.otherPredictedClassDocCount == that.otherPredictedClassDocCount; + } + + @Override + public int hashCode() { + return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount); + } + } + + public static class PredictedClass implements ToXContentObject, Writeable { + + private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class"); + private static final ParseField COUNT = new ParseField("count"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1])); + + static { + PARSER.declareString(constructorArg(), PREDICTED_CLASS); + PARSER.declareLong(constructorArg(), COUNT); + } + + private final String predictedClass; + private final long count; + + public PredictedClass(String predictedClass, long count) { + this.predictedClass = ExceptionsHelper.requireNonNull(predictedClass, PREDICTED_CLASS); + this.count = requireNonNegative(count, COUNT); + } + + public PredictedClass(StreamInput in) throws IOException { + this.predictedClass = in.readString(); + this.count = in.readVLong(); + } + + public String getPredictedClass() { + return predictedClass; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(predictedClass); + out.writeVLong(count); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass); + builder.field(COUNT.getPreferredName(), count); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PredictedClass that = (PredictedClass) o; + return Objects.equals(this.predictedClass, that.predictedClass) + && this.count == that.count; + } + + @Override + public int hashCode() { + return Objects.hash(predictedClass, count); + } + } + + private static long requireNonNegative(long value, ParseField field) { + if (value < 0) { + throw ExceptionsHelper.serverError("[" + field.getPreferredName() + "] must be >= 0, was: " + value); + } + return value; + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java index 24b13d372d5..a2c30eaeb49 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java @@ -5,50 +5,53 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.TreeMap; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; -public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase { +import static org.hamcrest.Matchers.equalTo; - public static MulticlassConfusionMatrix.Result createRandom() { +public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase { + + public static Result createRandom() { int numClasses = randomIntBetween(2, 100); List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); - Map> confusionMatrix = new TreeMap<>(); + List actualClasses = new ArrayList<>(numClasses); for (int i = 0; i < numClasses; i++) { - Map row = new TreeMap<>(); - confusionMatrix.put(classNames.get(i), row); + List predictedClasses = new ArrayList<>(numClasses); for (int j = 0; j < numClasses; j++) { - if (randomBoolean()) { - row.put(classNames.get(i), randomNonNegativeLong()); - } + predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong())); } + actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong())); } - long otherClassesCount = randomNonNegativeLong(); - return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount); + return new Result(actualClasses, randomNonNegativeLong()); } @Override - protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException { - return MulticlassConfusionMatrix.Result.fromXContent(parser); + protected Result doParseInstance(XContentParser parser) throws IOException { + return Result.fromXContent(parser); } @Override - protected MulticlassConfusionMatrix.Result createTestInstance() { + protected Result createTestInstance() { return createRandom(); } @Override - protected Writeable.Reader instanceReader() { - return MulticlassConfusionMatrix.Result::new; + protected Writeable.Reader instanceReader() { + return Result::new; } @Override @@ -61,4 +64,67 @@ public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTes // allow unknown fields in the root of the object only return field -> !field.isEmpty(); } + + public void testConstructor_ValidationFailures() { + { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new Result(null, 0)); + assertThat(e.getMessage(), equalTo("[confusion_matrix] must not be null.")); + } + { + ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new Result(Collections.emptyList(), -1)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[other_actual_class_count] must be >= 0, was: -1")); + } + { + IllegalArgumentException e = + expectThrows( + IllegalArgumentException.class, + () -> new Result(Collections.singletonList(new ActualClass(null, 0, Collections.emptyList(), 0)), 0)); + assertThat(e.getMessage(), equalTo("[actual_class] must not be null.")); + } + { + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> new Result(Collections.singletonList(new ActualClass("actual_class", -1, Collections.emptyList(), 0)), 0)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[actual_class_doc_count] must be >= 0, was: -1")); + } + { + IllegalArgumentException e = + expectThrows( + IllegalArgumentException.class, + () -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, null, 0)), 0)); + assertThat(e.getMessage(), equalTo("[predicted_classes] must not be null.")); + } + { + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, Collections.emptyList(), -1)), 0)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[other_predicted_class_doc_count] must be >= 0, was: -1")); + } + { + IllegalArgumentException e = + expectThrows( + IllegalArgumentException.class, + () -> new Result( + Collections.singletonList( + new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass(null, 0)), 0)), + 0)); + assertThat(e.getMessage(), equalTo("[predicted_class] must not be null.")); + } + { + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> new Result( + Collections.singletonList( + new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass("predicted_class", -1)), 0)), + 0)); + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), equalTo("[count] must be >= 0, was: -1")); + } + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java index ff788460b49..0991093c9ee 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -14,13 +14,13 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters; import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.metrics.Cardinality; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import java.io.IOException; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Optional; import static org.hamcrest.Matchers.empty; @@ -88,22 +88,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 0L), - mockTerms( + mockFilters( "multiclass_confusion_matrix_step_2_by_actual_class", Arrays.asList( - mockTermsBucket( + mockFiltersBucket( "dog", + 30, new Aggregations(Arrays.asList(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", Arrays.asList( mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), - mockTermsBucket( + mockFiltersBucket( "cat", + 70, new Aggregations(Arrays.asList(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", Arrays.asList( - mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))), - 0L), + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))), mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); @@ -112,15 +113,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); - Map> expectedConfusionMatrix = new HashMap<>(); - expectedConfusionMatrix.put("dog", new HashMap<>()); - expectedConfusionMatrix.get("dog").put("cat", 10L); - expectedConfusionMatrix.get("dog").put("dog", 20L); - expectedConfusionMatrix.put("cat", new HashMap<>()); - expectedConfusionMatrix.get("cat").put("cat", 30L); - expectedConfusionMatrix.get("cat").put("dog", 40L); - assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); - assertThat(result.getOtherClassesCount(), equalTo(0L)); + assertThat( + result.getConfusionMatrix(), + equalTo( + Arrays.asList( + new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), + new ActualClass("cat", 70, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0)))); + assertThat(result.getOtherActualClassCount(), equalTo(0L)); } public void testEvaluate_OtherClassesCountGreaterThanZero() { @@ -131,22 +130,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), 100L), - mockTerms( + mockFilters( "multiclass_confusion_matrix_step_2_by_actual_class", Arrays.asList( - mockTermsBucket( + mockFiltersBucket( "dog", + 30, new Aggregations(Arrays.asList(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", Arrays.asList( mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), - mockTermsBucket( + mockFiltersBucket( "cat", + 85, new Aggregations(Arrays.asList(mockFilters( "multiclass_confusion_matrix_step_2_by_predicted_class", Arrays.asList( - mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))), - 100L), + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))), mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); @@ -155,16 +155,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); - Map> expectedConfusionMatrix = new HashMap<>(); - expectedConfusionMatrix.put("dog", new HashMap<>()); - expectedConfusionMatrix.get("dog").put("cat", 10L); - expectedConfusionMatrix.get("dog").put("dog", 20L); - expectedConfusionMatrix.put("cat", new HashMap<>()); - expectedConfusionMatrix.get("cat").put("cat", 30L); - expectedConfusionMatrix.get("cat").put("dog", 40L); - expectedConfusionMatrix.get("cat").put("_other_", 15L); - assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); - assertThat(result.getOtherClassesCount(), equalTo(3L)); + assertThat( + result.getConfusionMatrix(), + equalTo( + Arrays.asList( + new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0), + new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15)))); + assertThat(result.getOtherActualClassCount(), equalTo(3L)); } private static Terms mockTerms(String name, List buckets, long sumOfOtherDocCounts) { @@ -175,9 +172,9 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< return aggregation; } - private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) { + private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) { Terms.Bucket bucket = mock(Terms.Bucket.class); - when(bucket.getKeyAsString()).thenReturn(actualClass); + when(bucket.getKeyAsString()).thenReturn(key); when(bucket.getAggregations()).thenReturn(subAggs); return bucket; } @@ -189,9 +186,15 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase< return aggregation; } - private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) { + private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) { + Filters.Bucket bucket = mockFiltersBucket(key, docCount); + when(bucket.getAggregations()).thenReturn(subAggs); + return bucket; + } + + private static Filters.Bucket mockFiltersBucket(String key, long docCount) { Filters.Bucket bucket = mock(Filters.Bucket.class); - when(bucket.getKeyAsString()).thenReturn(predictedClass); + when(bucket.getKeyAsString()).thenReturn(key); when(bucket.getDocCount()).thenReturn(docCount); return bucket; } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java index ba70828f5c1..299f5e596fb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -12,13 +12,13 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass; import org.junit.After; import org.junit.Before; import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static org.hamcrest.Matchers.equalTo; @@ -53,39 +53,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT MulticlassConfusionMatrix.Result confusionMatrixResult = (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); - Map> expectedConfusionMatrix = new HashMap<>(); - expectedConfusionMatrix.put("ant", new HashMap<>()); - expectedConfusionMatrix.get("ant").put("ant", 1L); - expectedConfusionMatrix.get("ant").put("cat", 4L); - expectedConfusionMatrix.get("ant").put("dog", 3L); - expectedConfusionMatrix.get("ant").put("fox", 2L); - expectedConfusionMatrix.get("ant").put("mouse", 5L); - expectedConfusionMatrix.put("cat", new HashMap<>()); - expectedConfusionMatrix.get("cat").put("ant", 3L); - expectedConfusionMatrix.get("cat").put("cat", 1L); - expectedConfusionMatrix.get("cat").put("dog", 5L); - expectedConfusionMatrix.get("cat").put("fox", 4L); - expectedConfusionMatrix.get("cat").put("mouse", 2L); - expectedConfusionMatrix.put("dog", new HashMap<>()); - expectedConfusionMatrix.get("dog").put("ant", 4L); - expectedConfusionMatrix.get("dog").put("cat", 2L); - expectedConfusionMatrix.get("dog").put("dog", 1L); - expectedConfusionMatrix.get("dog").put("fox", 5L); - expectedConfusionMatrix.get("dog").put("mouse", 3L); - expectedConfusionMatrix.put("fox", new HashMap<>()); - expectedConfusionMatrix.get("fox").put("ant", 5L); - expectedConfusionMatrix.get("fox").put("cat", 3L); - expectedConfusionMatrix.get("fox").put("dog", 2L); - expectedConfusionMatrix.get("fox").put("fox", 1L); - expectedConfusionMatrix.get("fox").put("mouse", 4L); - expectedConfusionMatrix.put("mouse", new HashMap<>()); - expectedConfusionMatrix.get("mouse").put("ant", 2L); - expectedConfusionMatrix.get("mouse").put("cat", 5L); - expectedConfusionMatrix.get("mouse").put("dog", 4L); - expectedConfusionMatrix.get("mouse").put("fox", 3L); - expectedConfusionMatrix.get("mouse").put("mouse", 1L); - assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); - assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + assertThat( + confusionMatrixResult.getConfusionMatrix(), + equalTo(Arrays.asList( + new ActualClass("ant", + 15, + Arrays.asList( + new PredictedClass("ant", 1L), + new PredictedClass("cat", 4L), + new PredictedClass("dog", 3L), + new PredictedClass("fox", 2L), + new PredictedClass("mouse", 5L)), + 0), + new ActualClass("cat", + 15, + Arrays.asList( + new PredictedClass("ant", 3L), + new PredictedClass("cat", 1L), + new PredictedClass("dog", 5L), + new PredictedClass("fox", 4L), + new PredictedClass("mouse", 2L)), + 0), + new ActualClass("dog", + 15, + Arrays.asList( + new PredictedClass("ant", 4L), + new PredictedClass("cat", 2L), + new PredictedClass("dog", 1L), + new PredictedClass("fox", 5L), + new PredictedClass("mouse", 3L)), + 0), + new ActualClass("fox", + 15, + Arrays.asList( + new PredictedClass("ant", 5L), + new PredictedClass("cat", 3L), + new PredictedClass("dog", 2L), + new PredictedClass("fox", 1L), + new PredictedClass("mouse", 4L)), + 0), + new ActualClass("mouse", + 15, + Arrays.asList( + new PredictedClass("ant", 2L), + new PredictedClass("cat", 5L), + new PredictedClass("dog", 4L), + new PredictedClass("fox", 3L), + new PredictedClass("mouse", 1L)), + 0)))); + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() { @@ -103,39 +119,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT MulticlassConfusionMatrix.Result confusionMatrixResult = (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); - Map> expectedConfusionMatrix = new HashMap<>(); - expectedConfusionMatrix.put("ant", new HashMap<>()); - expectedConfusionMatrix.get("ant").put("ant", 1L); - expectedConfusionMatrix.get("ant").put("cat", 4L); - expectedConfusionMatrix.get("ant").put("dog", 3L); - expectedConfusionMatrix.get("ant").put("fox", 2L); - expectedConfusionMatrix.get("ant").put("mouse", 5L); - expectedConfusionMatrix.put("cat", new HashMap<>()); - expectedConfusionMatrix.get("cat").put("ant", 3L); - expectedConfusionMatrix.get("cat").put("cat", 1L); - expectedConfusionMatrix.get("cat").put("dog", 5L); - expectedConfusionMatrix.get("cat").put("fox", 4L); - expectedConfusionMatrix.get("cat").put("mouse", 2L); - expectedConfusionMatrix.put("dog", new HashMap<>()); - expectedConfusionMatrix.get("dog").put("ant", 4L); - expectedConfusionMatrix.get("dog").put("cat", 2L); - expectedConfusionMatrix.get("dog").put("dog", 1L); - expectedConfusionMatrix.get("dog").put("fox", 5L); - expectedConfusionMatrix.get("dog").put("mouse", 3L); - expectedConfusionMatrix.put("fox", new HashMap<>()); - expectedConfusionMatrix.get("fox").put("ant", 5L); - expectedConfusionMatrix.get("fox").put("cat", 3L); - expectedConfusionMatrix.get("fox").put("dog", 2L); - expectedConfusionMatrix.get("fox").put("fox", 1L); - expectedConfusionMatrix.get("fox").put("mouse", 4L); - expectedConfusionMatrix.put("mouse", new HashMap<>()); - expectedConfusionMatrix.get("mouse").put("ant", 2L); - expectedConfusionMatrix.get("mouse").put("cat", 5L); - expectedConfusionMatrix.get("mouse").put("dog", 4L); - expectedConfusionMatrix.get("mouse").put("fox", 3L); - expectedConfusionMatrix.get("mouse").put("mouse", 1L); - assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); - assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + assertThat( + confusionMatrixResult.getConfusionMatrix(), + equalTo(Arrays.asList( + new ActualClass("ant", + 15, + Arrays.asList( + new PredictedClass("ant", 1L), + new PredictedClass("cat", 4L), + new PredictedClass("dog", 3L), + new PredictedClass("fox", 2L), + new PredictedClass("mouse", 5L)), + 0), + new ActualClass("cat", + 15, + Arrays.asList( + new PredictedClass("ant", 3L), + new PredictedClass("cat", 1L), + new PredictedClass("dog", 5L), + new PredictedClass("fox", 4L), + new PredictedClass("mouse", 2L)), + 0), + new ActualClass("dog", + 15, + Arrays.asList( + new PredictedClass("ant", 4L), + new PredictedClass("cat", 2L), + new PredictedClass("dog", 1L), + new PredictedClass("fox", 5L), + new PredictedClass("mouse", 3L)), + 0), + new ActualClass("fox", + 15, + Arrays.asList( + new PredictedClass("ant", 5L), + new PredictedClass("cat", 3L), + new PredictedClass("dog", 2L), + new PredictedClass("fox", 1L), + new PredictedClass("mouse", 4L)), + 0), + new ActualClass("mouse", + 15, + Arrays.asList( + new PredictedClass("ant", 2L), + new PredictedClass("cat", 5L), + new PredictedClass("dog", 4L), + new PredictedClass("fox", 3L), + new PredictedClass("mouse", 1L)), + 0)))); + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() { @@ -153,24 +185,22 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT MulticlassConfusionMatrix.Result confusionMatrixResult = (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); - Map> expectedConfusionMatrix = new HashMap<>(); - expectedConfusionMatrix.put("ant", new HashMap<>()); - expectedConfusionMatrix.get("ant").put("ant", 1L); - expectedConfusionMatrix.get("ant").put("cat", 4L); - expectedConfusionMatrix.get("ant").put("dog", 3L); - expectedConfusionMatrix.get("ant").put("_other_", 7L); - expectedConfusionMatrix.put("cat", new HashMap<>()); - expectedConfusionMatrix.get("cat").put("ant", 3L); - expectedConfusionMatrix.get("cat").put("cat", 1L); - expectedConfusionMatrix.get("cat").put("dog", 5L); - expectedConfusionMatrix.get("cat").put("_other_", 6L); - expectedConfusionMatrix.put("dog", new HashMap<>()); - expectedConfusionMatrix.get("dog").put("ant", 4L); - expectedConfusionMatrix.get("dog").put("cat", 2L); - expectedConfusionMatrix.get("dog").put("dog", 1L); - expectedConfusionMatrix.get("dog").put("_other_", 8L); - assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); - assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L)); + assertThat( + confusionMatrixResult.getConfusionMatrix(), + equalTo(Arrays.asList( + new ActualClass("ant", + 15, + Arrays.asList(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)), + 7), + new ActualClass("cat", + 15, + Arrays.asList(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)), + 6), + new ActualClass("dog", + 15, + Arrays.asList(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)), + 8)))); + assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L)); } private static void indexAnimalsData(String indexName) { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 1bcde11f2fb..f35346fc785 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -618,8 +618,40 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } - - match: { classification.multiclass_confusion_matrix._other_: 0 } + - match: + classification.multiclass_confusion_matrix: + confusion_matrix: + - actual_class: "cat" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 2 + - predicted_class: "dog" + count: 1 + - predicted_class: "mouse" + count: 0 + other_predicted_class_doc_count: 0 + - actual_class: "dog" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 1 + - predicted_class: "dog" + count: 2 + - predicted_class: "mouse" + count: 0 + other_predicted_class_doc_count: 0 + - actual_class: "mouse" + actual_class_doc_count: 2 + predicted_classes: + - predicted_class: "cat" + count: 1 + - predicted_class: "dog" + count: 0 + - predicted_class: "mouse" + count: 1 + other_predicted_class_doc_count: 0 + other_actual_class_count: 0 --- "Test classification multiclass_confusion_matrix with explicit size": - do: @@ -636,8 +668,26 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } } - - match: { classification.multiclass_confusion_matrix._other_: 1 } + - match: + classification.multiclass_confusion_matrix: + confusion_matrix: + - actual_class: "cat" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 2 + - predicted_class: "dog" + count: 1 + other_predicted_class_doc_count: 0 + - actual_class: "dog" + actual_class_doc_count: 3 + predicted_classes: + - predicted_class: "cat" + count: 1 + - predicted_class: "dog" + count: 2 + other_predicted_class_doc_count: 0 + other_actual_class_count: 1 --- "Test classification with null metrics": - do: @@ -653,8 +703,7 @@ setup: } } - - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } - - match: { classification.multiclass_confusion_matrix._other_: 0 } + - is_true: classification.multiclass_confusion_matrix --- "Test classification given missing actual_field": - do: