[7.x] Change format of MulticlassConfusionMatrix result to be more self-explanatory (#48174) (#48294)

This commit is contained in:
Przemysław Witek 2019-10-21 22:07:19 +02:00 committed by GitHub
parent 178204703a
commit 2db2b945ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 714 additions and 277 deletions

View File

@ -21,17 +21,17 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.TreeMap;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/** /**
@ -97,32 +97,28 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
public static class Result implements EvaluationMetric.Result { public static class Result implements EvaluationMetric.Result {
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>( new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1])); "multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));
static { static {
PARSER.declareObject( PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
constructorArg(), PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT);
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
} }
public static Result fromXContent(XContentParser parser) { public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
// Immutable private final List<ActualClass> confusionMatrix;
private final Map<String, Map<String, Long>> confusionMatrix; private final Long otherActualClassCount;
private final long otherClassesCount;
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) { public Result(@Nullable List<ActualClass> confusionMatrix, @Nullable Long otherActualClassCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null;
this.otherClassesCount = otherClassesCount; this.otherActualClassCount = otherActualClassCount;
} }
@Override @Override
@ -130,19 +126,23 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
return NAME; return NAME;
} }
public Map<String, Map<String, Long>> getConfusionMatrix() { public List<ActualClass> getConfusionMatrix() {
return confusionMatrix; return confusionMatrix;
} }
public long getOtherClassesCount() { public Long getOtherActualClassCount() {
return otherClassesCount; return otherActualClassCount;
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
if (confusionMatrix != null) {
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); }
if (otherActualClassCount != null) {
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -153,12 +153,140 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix) return Objects.equals(this.confusionMatrix, that.confusionMatrix)
&& this.otherClassesCount == that.otherClassesCount; && Objects.equals(this.otherActualClassCount, that.otherActualClassCount);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount); return Objects.hash(confusionMatrix, otherActualClassCount);
}
}
public static class ActualClass implements ToXContentObject {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_actual_class",
true,
a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));
static {
PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS);
PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
}
private final String actualClass;
private final Long actualClassDocCount;
private final List<PredictedClass> predictedClasses;
private final Long otherPredictedClassDocCount;
public ActualClass(@Nullable String actualClass,
@Nullable Long actualClassDocCount,
@Nullable List<PredictedClass> predictedClasses,
@Nullable Long otherPredictedClassDocCount) {
this.actualClass = actualClass;
this.actualClassDocCount = actualClassDocCount;
this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null;
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (actualClass != null) {
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
}
if (actualClassDocCount != null) {
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
}
if (predictedClasses != null) {
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
}
if (otherPredictedClassDocCount != null) {
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& Objects.equals(this.actualClassDocCount, that.actualClassDocCount)
&& Objects.equals(this.predictedClasses, that.predictedClasses)
&& Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount);
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
}
@Override
public String toString() {
return Strings.toString(this);
}
}
public static class PredictedClass implements ToXContentObject {
private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
private static final ParseField COUNT = new ParseField("count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));
static {
PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS);
PARSER.declareLong(optionalConstructorArg(), COUNT);
}
private final String predictedClass;
private final Long count;
public PredictedClass(@Nullable String predictedClass, @Nullable Long count) {
this.predictedClass = predictedClass;
this.count = count;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (predictedClass != null) {
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
}
if (count != null) {
builder.field(COUNT.getPreferredName(), count);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PredictedClass that = (PredictedClass) o;
return Objects.equals(this.predictedClass, that.predictedClass)
&& Objects.equals(this.count, that.count);
}
@Override
public int hashCode() {
return Objects.hash(predictedClass, count);
} }
} }
} }

View File

@ -127,6 +127,8 @@ import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
@ -1807,7 +1809,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "horse", "cat")); .add(docForClassification(indexName, "ant", "cat"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@ -1827,22 +1829,26 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
MulticlassConfusionMatrixMetric.Result mcmResult = MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME); evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>(); assertThat(
expectedConfusionMatrix.put("cat", new HashMap<>()); mcmResult.getConfusionMatrix(),
expectedConfusionMatrix.get("cat").put("cat", 3L); equalTo(
expectedConfusionMatrix.get("cat").put("dog", 1L); Arrays.asList(
expectedConfusionMatrix.get("cat").put("horse", 0L); new ActualClass(
expectedConfusionMatrix.get("cat").put("_other_", 1L); "ant",
expectedConfusionMatrix.put("dog", new HashMap<>()); 1L,
expectedConfusionMatrix.get("dog").put("cat", 1L); Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
expectedConfusionMatrix.get("dog").put("dog", 3L); 0L),
expectedConfusionMatrix.get("dog").put("horse", 0L); new ActualClass(
expectedConfusionMatrix.put("horse", new HashMap<>()); "cat",
expectedConfusionMatrix.get("horse").put("cat", 1L); 5L,
expectedConfusionMatrix.get("horse").put("dog", 0L); Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
expectedConfusionMatrix.get("horse").put("horse", 0L); 1L),
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); new ActualClass(
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L)); "dog",
4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
} }
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric { // Explicit size provided for MulticlassConfusionMatrixMetric metric
EvaluateDataFrameRequest evaluateDataFrameRequest = EvaluateDataFrameRequest evaluateDataFrameRequest =
@ -1859,16 +1865,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
MulticlassConfusionMatrixMetric.Result mcmResult = MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME); evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>(); assertThat(
expectedConfusionMatrix.put("cat", new HashMap<>()); mcmResult.getConfusionMatrix(),
expectedConfusionMatrix.get("cat").put("cat", 3L); equalTo(
expectedConfusionMatrix.get("cat").put("dog", 1L); Arrays.asList(
expectedConfusionMatrix.get("cat").put("_other_", 1L); new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
expectedConfusionMatrix.put("dog", new HashMap<>()); new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
expectedConfusionMatrix.get("dog").put("cat", 1L); )));
expectedConfusionMatrix.get("dog").put("dog", 3L); assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
} }
} }

View File

@ -142,6 +142,8 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -3355,33 +3357,31 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1> response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2> List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3> long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
// end::evaluate-data-frame-results-classification // end::evaluate-data-frame-results-classification
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
assertThat( assertThat(
confusionMatrix, confusionMatrix,
equalTo( equalTo(
new HashMap<String, Map<String, Long>>() {{ Arrays.asList(
put("cat", new HashMap<String, Long>() {{ new ActualClass(
put("cat", 3L); "ant",
put("dog", 1L); 1L,
put("ant", 0L); Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
put("_other_", 1L); 0L),
}}); new ActualClass(
put("dog", new HashMap<String, Long>() {{ "cat",
put("cat", 1L); 5L,
put("dog", 3L); Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
put("ant", 0L); 1L),
}}); new ActualClass(
put("ant", new HashMap<String, Long>() {{ "dog",
put("cat", 1L); 4L,
put("dog", 0L); Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
put("ant", 0L); 0L))));
}}); assertThat(otherActualClassCount, equalTo(0L));
}}));
assertThat(otherClassesCount, equalTo(0L));
} }
} }

View File

@ -19,19 +19,21 @@
package org.elasticsearch.client.ml.dataframe.evaluation.classification; package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric.Result> { public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<Result> {
@Override @Override
protected NamedXContentRegistry xContentRegistry() { protected NamedXContentRegistry xContentRegistry() {
@ -39,26 +41,28 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
} }
@Override @Override
protected MulticlassConfusionMatrixMetric.Result createTestInstance() { protected Result createTestInstance() {
int numClasses = randomIntBetween(2, 100); int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>(); List<ActualClass> actualClasses = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) { for (int i = 0; i < numClasses; i++) {
Map<String, Long> row = new TreeMap<>(); List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
confusionMatrix.put(classNames.get(i), row);
for (int j = 0; j < numClasses; j++) { for (int j = 0; j < numClasses; j++) {
if (randomBoolean()) { predictedClasses.add(new PredictedClass(classNames.get(j), randomBoolean() ? randomNonNegativeLong() : null));
row.put(classNames.get(i), randomNonNegativeLong());
} }
actualClasses.add(
new ActualClass(
classNames.get(i),
randomBoolean() ? randomNonNegativeLong() : null,
predictedClasses,
randomBoolean() ? randomNonNegativeLong() : null));
} }
} return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
long otherClassesCount = randomNonNegativeLong();
return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount);
} }
@Override @Override
protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException { protected Result doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrixMetric.Result.fromXContent(parser); return Result.fromXContent(parser);
} }
@Override @Override

View File

@ -9,7 +9,9 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
@ -26,14 +28,14 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.TreeMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static java.util.Comparator.comparing;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -112,18 +114,19 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
.size(size)); .size(size));
} }
if (result == null) { // This is step 2 if (result == null) { // This is step 2
KeyedFilter[] keyedFilters = KeyedFilter[] keyedFiltersActual =
topActualClassNames.stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
.toArray(KeyedFilter[]::new);
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.stream() topActualClassNames.stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new); .toArray(KeyedFilter[]::new);
return Arrays.asList( return Arrays.asList(
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
.field(actualField), .field(actualField),
AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
.field(actualField) .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters)
.otherBucket(true) .otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY))); .otherBucketKey(OTHER_BUCKET_KEY)));
} }
@ -134,26 +137,31 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
public void process(Aggregations aggs) { public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList()); topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
} }
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
Map<String, Map<String, Long>> counts = new TreeMap<>(); List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
for (Terms.Bucket bucket : termsAgg.getBuckets()) { for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString(); String actualClass = bucket.getKeyAsString();
Map<String, Long> subCounts = new TreeMap<>(); long actualClassDocCount = bucket.getDocCount();
counts.put(actualClass, subCounts);
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
List<PredictedClass> predictedClasses = new ArrayList<>();
long otherPredictedClassDocCount = 0;
for (Filters.Bucket subBucket : subAgg.getBuckets()) { for (Filters.Bucket subBucket : subAgg.getBuckets()) {
String predictedClass = subBucket.getKeyAsString(); String predictedClass = subBucket.getKeyAsString();
Long docCount = subBucket.getDocCount(); long docCount = subBucket.getDocCount();
if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) { if (OTHER_BUCKET_KEY.equals(predictedClass)) {
subCounts.put(predictedClass, docCount); otherPredictedClassDocCount = docCount;
} else {
predictedClasses.add(new PredictedClass(predictedClass, docCount));
} }
} }
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
} }
result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size); result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
} }
} }
@ -191,37 +199,35 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
public static class Result implements EvaluationMetricResult { public static class Result implements EvaluationMetricResult {
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>( new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1])); "multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (long) a[1]));
static { static {
PARSER.declareObject( PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
constructorArg(), PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT);
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
} }
public static Result fromXContent(XContentParser parser) { public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
// Immutable /** List of actual classes. */
private final Map<String, Map<String, Long>> confusionMatrix; private final List<ActualClass> actualClasses;
private final long otherClassesCount; /** Number of actual classes that were not included in the confusion matrix because there were too many of them. */
private final long otherActualClassCount;
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) { public Result(List<ActualClass> actualClasses, long otherActualClassCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, CONFUSION_MATRIX));
this.otherClassesCount = otherClassesCount; this.otherActualClassCount = requireNonNegative(otherActualClassCount, OTHER_ACTUAL_CLASS_COUNT);
} }
public Result(StreamInput in) throws IOException { public Result(StreamInput in) throws IOException {
this.confusionMatrix = Collections.unmodifiableMap( this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong))); this.otherActualClassCount = in.readVLong();
this.otherClassesCount = in.readLong();
} }
@Override @Override
@ -234,28 +240,25 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
return NAME.getPreferredName(); return NAME.getPreferredName();
} }
public Map<String, Map<String, Long>> getConfusionMatrix() { public List<ActualClass> getConfusionMatrix() {
return confusionMatrix; return actualClasses;
} }
public long getOtherClassesCount() { public long getOtherActualClassCount() {
return otherClassesCount; return otherActualClassCount;
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeMap( out.writeList(actualClasses);
confusionMatrix, out.writeVLong(otherActualClassCount);
StreamOutput::writeString,
(out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong));
out.writeLong(otherClassesCount);
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); builder.field(CONFUSION_MATRIX.getPreferredName(), actualClasses);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -265,13 +268,163 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix) return Objects.equals(this.actualClasses, that.actualClasses)
&& this.otherClassesCount == that.otherClassesCount; && this.otherActualClassCount == that.otherActualClassCount;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount); return Objects.hash(actualClasses, otherActualClassCount);
} }
} }
public static class ActualClass implements ToXContentObject, Writeable {
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_actual_class",
true,
a -> new ActualClass((String) a[0], (long) a[1], (List<PredictedClass>) a[2], (long) a[3]));
static {
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
}
/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {code actualClass} class. */
private final long actualClassDocCount;
/** List of predicted classes. */
private final List<PredictedClass> predictedClasses;
/** Number of documents that were not predicted as any of the {@code predictedClasses}. */
private final long otherPredictedClassDocCount;
public ActualClass(
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
this.actualClassDocCount = requireNonNegative(actualClassDocCount, ACTUAL_CLASS_DOC_COUNT);
this.predictedClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(predictedClasses, PREDICTED_CLASSES));
this.otherPredictedClassDocCount = requireNonNegative(otherPredictedClassDocCount, OTHER_PREDICTED_CLASS_DOC_COUNT);
}
public ActualClass(StreamInput in) throws IOException {
this.actualClass = in.readString();
this.actualClassDocCount = in.readVLong();
this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new));
this.otherPredictedClassDocCount = in.readVLong();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualClass);
out.writeVLong(actualClassDocCount);
out.writeList(predictedClasses);
out.writeVLong(otherPredictedClassDocCount);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ActualClass that = (ActualClass) o;
return Objects.equals(this.actualClass, that.actualClass)
&& this.actualClassDocCount == that.actualClassDocCount
&& Objects.equals(this.predictedClasses, that.predictedClasses)
&& this.otherPredictedClassDocCount == that.otherPredictedClassDocCount;
}
@Override
public int hashCode() {
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
}
}
public static class PredictedClass implements ToXContentObject, Writeable {
private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
private static final ParseField COUNT = new ParseField("count");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1]));
static {
PARSER.declareString(constructorArg(), PREDICTED_CLASS);
PARSER.declareLong(constructorArg(), COUNT);
}
private final String predictedClass;
private final long count;
public PredictedClass(String predictedClass, long count) {
this.predictedClass = ExceptionsHelper.requireNonNull(predictedClass, PREDICTED_CLASS);
this.count = requireNonNegative(count, COUNT);
}
public PredictedClass(StreamInput in) throws IOException {
this.predictedClass = in.readString();
this.count = in.readVLong();
}
public String getPredictedClass() {
return predictedClass;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(predictedClass);
out.writeVLong(count);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
builder.field(COUNT.getPreferredName(), count);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PredictedClass that = (PredictedClass) o;
return Objects.equals(this.predictedClass, that.predictedClass)
&& this.count == that.count;
}
@Override
public int hashCode() {
return Objects.hash(predictedClass, count);
}
}
private static long requireNonNegative(long value, ParseField field) {
if (value < 0) {
throw ExceptionsHelper.serverError("[" + field.getPreferredName() + "] must be >= 0, was: " + value);
}
return value;
}
} }

View File

@ -5,50 +5,53 @@
*/ */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix.Result> { import static org.hamcrest.Matchers.equalTo;
public static MulticlassConfusionMatrix.Result createRandom() { public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<Result> {
public static Result createRandom() {
int numClasses = randomIntBetween(2, 100); int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>(); List<ActualClass> actualClasses = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) { for (int i = 0; i < numClasses; i++) {
Map<String, Long> row = new TreeMap<>(); List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
confusionMatrix.put(classNames.get(i), row);
for (int j = 0; j < numClasses; j++) { for (int j = 0; j < numClasses; j++) {
if (randomBoolean()) { predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
row.put(classNames.get(i), randomNonNegativeLong());
} }
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong()));
} }
} return new Result(actualClasses, randomNonNegativeLong());
long otherClassesCount = randomNonNegativeLong();
return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount);
} }
@Override @Override
protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException { protected Result doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrix.Result.fromXContent(parser); return Result.fromXContent(parser);
} }
@Override @Override
protected MulticlassConfusionMatrix.Result createTestInstance() { protected Result createTestInstance() {
return createRandom(); return createRandom();
} }
@Override @Override
protected Writeable.Reader<MulticlassConfusionMatrix.Result> instanceReader() { protected Writeable.Reader<Result> instanceReader() {
return MulticlassConfusionMatrix.Result::new; return Result::new;
} }
@Override @Override
@ -61,4 +64,67 @@ public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTes
// allow unknown fields in the root of the object only // allow unknown fields in the root of the object only
return field -> !field.isEmpty(); return field -> !field.isEmpty();
} }
public void testConstructor_ValidationFailures() {
{
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new Result(null, 0));
assertThat(e.getMessage(), equalTo("[confusion_matrix] must not be null."));
}
{
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new Result(Collections.emptyList(), -1));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[other_actual_class_count] must be >= 0, was: -1"));
}
{
IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class,
() -> new Result(Collections.singletonList(new ActualClass(null, 0, Collections.emptyList(), 0)), 0));
assertThat(e.getMessage(), equalTo("[actual_class] must not be null."));
}
{
ElasticsearchException e =
expectThrows(
ElasticsearchException.class,
() -> new Result(Collections.singletonList(new ActualClass("actual_class", -1, Collections.emptyList(), 0)), 0));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[actual_class_doc_count] must be >= 0, was: -1"));
}
{
IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class,
() -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, null, 0)), 0));
assertThat(e.getMessage(), equalTo("[predicted_classes] must not be null."));
}
{
ElasticsearchException e =
expectThrows(
ElasticsearchException.class,
() -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, Collections.emptyList(), -1)), 0));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[other_predicted_class_doc_count] must be >= 0, was: -1"));
}
{
IllegalArgumentException e =
expectThrows(
IllegalArgumentException.class,
() -> new Result(
Collections.singletonList(
new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass(null, 0)), 0)),
0));
assertThat(e.getMessage(), equalTo("[predicted_class] must not be null."));
}
{
ElasticsearchException e =
expectThrows(
ElasticsearchException.class,
() -> new Result(
Collections.singletonList(
new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass("predicted_class", -1)), 0)),
0));
assertThat(e.status().getStatus(), equalTo(500));
assertThat(e.getMessage(), equalTo("[count] must be >= 0, was: -1"));
}
}
} }

View File

@ -14,13 +14,13 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality; import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
@ -88,22 +88,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
0L), 0L),
mockTerms( mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class", "multiclass_confusion_matrix_step_2_by_actual_class",
Arrays.asList( Arrays.asList(
mockTermsBucket( mockFiltersBucket(
"dog", "dog",
30,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", "multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockTermsBucket( mockFiltersBucket(
"cat", "cat",
70,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", "multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
0L),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
@ -112,15 +113,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>(); assertThat(
expectedConfusionMatrix.put("dog", new HashMap<>()); result.getConfusionMatrix(),
expectedConfusionMatrix.get("dog").put("cat", 10L); equalTo(
expectedConfusionMatrix.get("dog").put("dog", 20L); Arrays.asList(
expectedConfusionMatrix.put("cat", new HashMap<>()); new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0),
expectedConfusionMatrix.get("cat").put("cat", 30L); new ActualClass("cat", 70, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0))));
expectedConfusionMatrix.get("cat").put("dog", 40L); assertThat(result.getOtherActualClassCount(), equalTo(0L));
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(result.getOtherClassesCount(), equalTo(0L));
} }
public void testEvaluate_OtherClassesCountGreaterThanZero() { public void testEvaluate_OtherClassesCountGreaterThanZero() {
@ -131,22 +130,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
mockTermsBucket("dog", new Aggregations(Collections.emptyList())), mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L), 100L),
mockTerms( mockFilters(
"multiclass_confusion_matrix_step_2_by_actual_class", "multiclass_confusion_matrix_step_2_by_actual_class",
Arrays.asList( Arrays.asList(
mockTermsBucket( mockFiltersBucket(
"dog", "dog",
30,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", "multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockTermsBucket( mockFiltersBucket(
"cat", "cat",
85,
new Aggregations(Arrays.asList(mockFilters( new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class", "multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList( Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))), mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
100L),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
@ -155,16 +155,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>(); assertThat(
expectedConfusionMatrix.put("dog", new HashMap<>()); result.getConfusionMatrix(),
expectedConfusionMatrix.get("dog").put("cat", 10L); equalTo(
expectedConfusionMatrix.get("dog").put("dog", 20L); Arrays.asList(
expectedConfusionMatrix.put("cat", new HashMap<>()); new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0),
expectedConfusionMatrix.get("cat").put("cat", 30L); new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
expectedConfusionMatrix.get("cat").put("dog", 40L); assertThat(result.getOtherActualClassCount(), equalTo(3L));
expectedConfusionMatrix.get("cat").put("_other_", 15L);
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(result.getOtherClassesCount(), equalTo(3L));
} }
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) { private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
@ -175,9 +172,9 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
return aggregation; return aggregation;
} }
private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) { private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
Terms.Bucket bucket = mock(Terms.Bucket.class); Terms.Bucket bucket = mock(Terms.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(actualClass); when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getAggregations()).thenReturn(subAggs); when(bucket.getAggregations()).thenReturn(subAggs);
return bucket; return bucket;
} }
@ -189,9 +186,15 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
return aggregation; return aggregation;
} }
private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) { private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
private static Filters.Bucket mockFiltersBucket(String key, long docCount) {
Filters.Bucket bucket = mock(Filters.Bucket.class); Filters.Bucket bucket = mock(Filters.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(predictedClass); when(bucket.getKeyAsString()).thenReturn(key);
when(bucket.getDocCount()).thenReturn(docCount); when(bucket.getDocCount()).thenReturn(docCount);
return bucket; return bucket;
} }

View File

@ -12,13 +12,13 @@ import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -53,39 +53,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
MulticlassConfusionMatrix.Result confusionMatrixResult = MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>(); assertThat(
expectedConfusionMatrix.put("ant", new HashMap<>()); confusionMatrixResult.getConfusionMatrix(),
expectedConfusionMatrix.get("ant").put("ant", 1L); equalTo(Arrays.asList(
expectedConfusionMatrix.get("ant").put("cat", 4L); new ActualClass("ant",
expectedConfusionMatrix.get("ant").put("dog", 3L); 15,
expectedConfusionMatrix.get("ant").put("fox", 2L); Arrays.asList(
expectedConfusionMatrix.get("ant").put("mouse", 5L); new PredictedClass("ant", 1L),
expectedConfusionMatrix.put("cat", new HashMap<>()); new PredictedClass("cat", 4L),
expectedConfusionMatrix.get("cat").put("ant", 3L); new PredictedClass("dog", 3L),
expectedConfusionMatrix.get("cat").put("cat", 1L); new PredictedClass("fox", 2L),
expectedConfusionMatrix.get("cat").put("dog", 5L); new PredictedClass("mouse", 5L)),
expectedConfusionMatrix.get("cat").put("fox", 4L); 0),
expectedConfusionMatrix.get("cat").put("mouse", 2L); new ActualClass("cat",
expectedConfusionMatrix.put("dog", new HashMap<>()); 15,
expectedConfusionMatrix.get("dog").put("ant", 4L); Arrays.asList(
expectedConfusionMatrix.get("dog").put("cat", 2L); new PredictedClass("ant", 3L),
expectedConfusionMatrix.get("dog").put("dog", 1L); new PredictedClass("cat", 1L),
expectedConfusionMatrix.get("dog").put("fox", 5L); new PredictedClass("dog", 5L),
expectedConfusionMatrix.get("dog").put("mouse", 3L); new PredictedClass("fox", 4L),
expectedConfusionMatrix.put("fox", new HashMap<>()); new PredictedClass("mouse", 2L)),
expectedConfusionMatrix.get("fox").put("ant", 5L); 0),
expectedConfusionMatrix.get("fox").put("cat", 3L); new ActualClass("dog",
expectedConfusionMatrix.get("fox").put("dog", 2L); 15,
expectedConfusionMatrix.get("fox").put("fox", 1L); Arrays.asList(
expectedConfusionMatrix.get("fox").put("mouse", 4L); new PredictedClass("ant", 4L),
expectedConfusionMatrix.put("mouse", new HashMap<>()); new PredictedClass("cat", 2L),
expectedConfusionMatrix.get("mouse").put("ant", 2L); new PredictedClass("dog", 1L),
expectedConfusionMatrix.get("mouse").put("cat", 5L); new PredictedClass("fox", 5L),
expectedConfusionMatrix.get("mouse").put("dog", 4L); new PredictedClass("mouse", 3L)),
expectedConfusionMatrix.get("mouse").put("fox", 3L); 0),
expectedConfusionMatrix.get("mouse").put("mouse", 1L); new ActualClass("fox",
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); 15,
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); Arrays.asList(
new PredictedClass("ant", 5L),
new PredictedClass("cat", 3L),
new PredictedClass("dog", 2L),
new PredictedClass("fox", 1L),
new PredictedClass("mouse", 4L)),
0),
new ActualClass("mouse",
15,
Arrays.asList(
new PredictedClass("ant", 2L),
new PredictedClass("cat", 5L),
new PredictedClass("dog", 4L),
new PredictedClass("fox", 3L),
new PredictedClass("mouse", 1L)),
0))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
} }
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() { public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
@ -103,39 +119,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
MulticlassConfusionMatrix.Result confusionMatrixResult = MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>(); assertThat(
expectedConfusionMatrix.put("ant", new HashMap<>()); confusionMatrixResult.getConfusionMatrix(),
expectedConfusionMatrix.get("ant").put("ant", 1L); equalTo(Arrays.asList(
expectedConfusionMatrix.get("ant").put("cat", 4L); new ActualClass("ant",
expectedConfusionMatrix.get("ant").put("dog", 3L); 15,
expectedConfusionMatrix.get("ant").put("fox", 2L); Arrays.asList(
expectedConfusionMatrix.get("ant").put("mouse", 5L); new PredictedClass("ant", 1L),
expectedConfusionMatrix.put("cat", new HashMap<>()); new PredictedClass("cat", 4L),
expectedConfusionMatrix.get("cat").put("ant", 3L); new PredictedClass("dog", 3L),
expectedConfusionMatrix.get("cat").put("cat", 1L); new PredictedClass("fox", 2L),
expectedConfusionMatrix.get("cat").put("dog", 5L); new PredictedClass("mouse", 5L)),
expectedConfusionMatrix.get("cat").put("fox", 4L); 0),
expectedConfusionMatrix.get("cat").put("mouse", 2L); new ActualClass("cat",
expectedConfusionMatrix.put("dog", new HashMap<>()); 15,
expectedConfusionMatrix.get("dog").put("ant", 4L); Arrays.asList(
expectedConfusionMatrix.get("dog").put("cat", 2L); new PredictedClass("ant", 3L),
expectedConfusionMatrix.get("dog").put("dog", 1L); new PredictedClass("cat", 1L),
expectedConfusionMatrix.get("dog").put("fox", 5L); new PredictedClass("dog", 5L),
expectedConfusionMatrix.get("dog").put("mouse", 3L); new PredictedClass("fox", 4L),
expectedConfusionMatrix.put("fox", new HashMap<>()); new PredictedClass("mouse", 2L)),
expectedConfusionMatrix.get("fox").put("ant", 5L); 0),
expectedConfusionMatrix.get("fox").put("cat", 3L); new ActualClass("dog",
expectedConfusionMatrix.get("fox").put("dog", 2L); 15,
expectedConfusionMatrix.get("fox").put("fox", 1L); Arrays.asList(
expectedConfusionMatrix.get("fox").put("mouse", 4L); new PredictedClass("ant", 4L),
expectedConfusionMatrix.put("mouse", new HashMap<>()); new PredictedClass("cat", 2L),
expectedConfusionMatrix.get("mouse").put("ant", 2L); new PredictedClass("dog", 1L),
expectedConfusionMatrix.get("mouse").put("cat", 5L); new PredictedClass("fox", 5L),
expectedConfusionMatrix.get("mouse").put("dog", 4L); new PredictedClass("mouse", 3L)),
expectedConfusionMatrix.get("mouse").put("fox", 3L); 0),
expectedConfusionMatrix.get("mouse").put("mouse", 1L); new ActualClass("fox",
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); 15,
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); Arrays.asList(
new PredictedClass("ant", 5L),
new PredictedClass("cat", 3L),
new PredictedClass("dog", 2L),
new PredictedClass("fox", 1L),
new PredictedClass("mouse", 4L)),
0),
new ActualClass("mouse",
15,
Arrays.asList(
new PredictedClass("ant", 2L),
new PredictedClass("cat", 5L),
new PredictedClass("dog", 4L),
new PredictedClass("fox", 3L),
new PredictedClass("mouse", 1L)),
0))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
} }
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() { public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
@ -153,24 +185,22 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
MulticlassConfusionMatrix.Result confusionMatrixResult = MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>(); assertThat(
expectedConfusionMatrix.put("ant", new HashMap<>()); confusionMatrixResult.getConfusionMatrix(),
expectedConfusionMatrix.get("ant").put("ant", 1L); equalTo(Arrays.asList(
expectedConfusionMatrix.get("ant").put("cat", 4L); new ActualClass("ant",
expectedConfusionMatrix.get("ant").put("dog", 3L); 15,
expectedConfusionMatrix.get("ant").put("_other_", 7L); Arrays.asList(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)),
expectedConfusionMatrix.put("cat", new HashMap<>()); 7),
expectedConfusionMatrix.get("cat").put("ant", 3L); new ActualClass("cat",
expectedConfusionMatrix.get("cat").put("cat", 1L); 15,
expectedConfusionMatrix.get("cat").put("dog", 5L); Arrays.asList(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)),
expectedConfusionMatrix.get("cat").put("_other_", 6L); 6),
expectedConfusionMatrix.put("dog", new HashMap<>()); new ActualClass("dog",
expectedConfusionMatrix.get("dog").put("ant", 4L); 15,
expectedConfusionMatrix.get("dog").put("cat", 2L); Arrays.asList(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)),
expectedConfusionMatrix.get("dog").put("dog", 1L); 8))));
expectedConfusionMatrix.get("dog").put("_other_", 8L); assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L));
} }
private static void indexAnimalsData(String indexName) { private static void indexAnimalsData(String indexName) {

View File

@ -618,8 +618,40 @@ setup:
} }
} }
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } - match:
- match: { classification.multiclass_confusion_matrix._other_: 0 } classification.multiclass_confusion_matrix:
confusion_matrix:
- actual_class: "cat"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 2
- predicted_class: "dog"
count: 1
- predicted_class: "mouse"
count: 0
other_predicted_class_doc_count: 0
- actual_class: "dog"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 1
- predicted_class: "dog"
count: 2
- predicted_class: "mouse"
count: 0
other_predicted_class_doc_count: 0
- actual_class: "mouse"
actual_class_doc_count: 2
predicted_classes:
- predicted_class: "cat"
count: 1
- predicted_class: "dog"
count: 0
- predicted_class: "mouse"
count: 1
other_predicted_class_doc_count: 0
other_actual_class_count: 0
--- ---
"Test classification multiclass_confusion_matrix with explicit size": "Test classification multiclass_confusion_matrix with explicit size":
- do: - do:
@ -636,8 +668,26 @@ setup:
} }
} }
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } } - match:
- match: { classification.multiclass_confusion_matrix._other_: 1 } classification.multiclass_confusion_matrix:
confusion_matrix:
- actual_class: "cat"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 2
- predicted_class: "dog"
count: 1
other_predicted_class_doc_count: 0
- actual_class: "dog"
actual_class_doc_count: 3
predicted_classes:
- predicted_class: "cat"
count: 1
- predicted_class: "dog"
count: 2
other_predicted_class_doc_count: 0
other_actual_class_count: 1
--- ---
"Test classification with null metrics": "Test classification with null metrics":
- do: - do:
@ -653,8 +703,7 @@ setup:
} }
} }
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } - is_true: classification.multiclass_confusion_matrix
- match: { classification.multiclass_confusion_matrix._other_: 0 }
--- ---
"Test classification given missing actual_field": "Test classification given missing actual_field":
- do: - do: