[7.x] Change format of MulticlassConfusionMatrix result to be more self-explanatory (#48174) (#48294)
This commit is contained in:
parent
178204703a
commit
2db2b945ec
|
@ -21,17 +21,17 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
|||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.TreeMap;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
|
@ -97,32 +97,28 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
|
|||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
|
||||
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
|
||||
private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
|
||||
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObject(
|
||||
constructorArg(),
|
||||
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
|
||||
CONFUSION_MATRIX);
|
||||
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
|
||||
PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
// Immutable
|
||||
private final Map<String, Map<String, Long>> confusionMatrix;
|
||||
private final long otherClassesCount;
|
||||
private final List<ActualClass> confusionMatrix;
|
||||
private final Long otherActualClassCount;
|
||||
|
||||
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
|
||||
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
|
||||
this.otherClassesCount = otherClassesCount;
|
||||
public Result(@Nullable List<ActualClass> confusionMatrix, @Nullable Long otherActualClassCount) {
|
||||
this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null;
|
||||
this.otherActualClassCount = otherActualClassCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -130,19 +126,23 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
|
|||
return NAME;
|
||||
}
|
||||
|
||||
public Map<String, Map<String, Long>> getConfusionMatrix() {
|
||||
public List<ActualClass> getConfusionMatrix() {
|
||||
return confusionMatrix;
|
||||
}
|
||||
|
||||
public long getOtherClassesCount() {
|
||||
return otherClassesCount;
|
||||
public Long getOtherActualClassCount() {
|
||||
return otherActualClassCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (confusionMatrix != null) {
|
||||
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
|
||||
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
|
||||
}
|
||||
if (otherActualClassCount != null) {
|
||||
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -153,12 +153,140 @@ public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
|
|||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
|
||||
&& this.otherClassesCount == that.otherClassesCount;
|
||||
&& Objects.equals(this.otherActualClassCount, that.otherActualClassCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(confusionMatrix, otherClassesCount);
|
||||
return Objects.hash(confusionMatrix, otherActualClassCount);
|
||||
}
|
||||
}
|
||||
|
||||
public static class ActualClass implements ToXContentObject {
|
||||
|
||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
||||
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
|
||||
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_actual_class",
|
||||
true,
|
||||
a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS);
|
||||
PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
|
||||
PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
|
||||
}
|
||||
|
||||
private final String actualClass;
|
||||
private final Long actualClassDocCount;
|
||||
private final List<PredictedClass> predictedClasses;
|
||||
private final Long otherPredictedClassDocCount;
|
||||
|
||||
public ActualClass(@Nullable String actualClass,
|
||||
@Nullable Long actualClassDocCount,
|
||||
@Nullable List<PredictedClass> predictedClasses,
|
||||
@Nullable Long otherPredictedClassDocCount) {
|
||||
this.actualClass = actualClass;
|
||||
this.actualClassDocCount = actualClassDocCount;
|
||||
this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null;
|
||||
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (actualClass != null) {
|
||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
||||
}
|
||||
if (actualClassDocCount != null) {
|
||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
||||
}
|
||||
if (predictedClasses != null) {
|
||||
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
|
||||
}
|
||||
if (otherPredictedClassDocCount != null) {
|
||||
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ActualClass that = (ActualClass) o;
|
||||
return Objects.equals(this.actualClass, that.actualClass)
|
||||
&& Objects.equals(this.actualClassDocCount, that.actualClassDocCount)
|
||||
&& Objects.equals(this.predictedClasses, that.predictedClasses)
|
||||
&& Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PredictedClass implements ToXContentObject {
|
||||
|
||||
private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
|
||||
private static final ParseField COUNT = new ParseField("count");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS);
|
||||
PARSER.declareLong(optionalConstructorArg(), COUNT);
|
||||
}
|
||||
|
||||
private final String predictedClass;
|
||||
private final Long count;
|
||||
|
||||
public PredictedClass(@Nullable String predictedClass, @Nullable Long count) {
|
||||
this.predictedClass = predictedClass;
|
||||
this.count = count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (predictedClass != null) {
|
||||
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
|
||||
}
|
||||
if (count != null) {
|
||||
builder.field(COUNT.getPreferredName(), count);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PredictedClass that = (PredictedClass) o;
|
||||
return Objects.equals(this.predictedClass, that.predictedClass)
|
||||
&& Objects.equals(this.count, that.count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(predictedClass, count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,6 +127,8 @@ import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
|||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
|
@ -1807,7 +1809,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "horse", "cat"));
|
||||
.add(docForClassification(indexName, "ant", "cat"));
|
||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
|
@ -1827,22 +1829,26 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
MulticlassConfusionMatrixMetric.Result mcmResult =
|
||||
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
|
||||
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("horse", 0L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 1L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("dog").put("horse", 0L);
|
||||
expectedConfusionMatrix.put("horse", new HashMap<>());
|
||||
expectedConfusionMatrix.get("horse").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("horse").put("dog", 0L);
|
||||
expectedConfusionMatrix.get("horse").put("horse", 0L);
|
||||
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
|
||||
assertThat(
|
||||
mcmResult.getConfusionMatrix(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new ActualClass(
|
||||
"ant",
|
||||
1L,
|
||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
|
||||
0L),
|
||||
new ActualClass(
|
||||
"cat",
|
||||
5L,
|
||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
|
||||
1L),
|
||||
new ActualClass(
|
||||
"dog",
|
||||
4L,
|
||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
|
||||
0L))));
|
||||
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
|
@ -1859,16 +1865,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
MulticlassConfusionMatrixMetric.Result mcmResult =
|
||||
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
|
||||
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 1L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 3L);
|
||||
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
|
||||
assertThat(
|
||||
mcmResult.getConfusionMatrix(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new ActualClass("cat", 5L, Arrays.asList(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
|
||||
new ActualClass("dog", 4L, Arrays.asList(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
|
||||
)));
|
||||
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -142,6 +142,8 @@ import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
|||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
|
@ -3355,33 +3357,31 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
|
||||
|
||||
Map<String, Map<String, Long>> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
|
||||
long otherClassesCount = multiclassConfusionMatrix.getOtherClassesCount(); // <3>
|
||||
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
|
||||
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
|
||||
// end::evaluate-data-frame-results-classification
|
||||
|
||||
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
assertThat(
|
||||
confusionMatrix,
|
||||
equalTo(
|
||||
new HashMap<String, Map<String, Long>>() {{
|
||||
put("cat", new HashMap<String, Long>() {{
|
||||
put("cat", 3L);
|
||||
put("dog", 1L);
|
||||
put("ant", 0L);
|
||||
put("_other_", 1L);
|
||||
}});
|
||||
put("dog", new HashMap<String, Long>() {{
|
||||
put("cat", 1L);
|
||||
put("dog", 3L);
|
||||
put("ant", 0L);
|
||||
}});
|
||||
put("ant", new HashMap<String, Long>() {{
|
||||
put("cat", 1L);
|
||||
put("dog", 0L);
|
||||
put("ant", 0L);
|
||||
}});
|
||||
}}));
|
||||
assertThat(otherClassesCount, equalTo(0L));
|
||||
Arrays.asList(
|
||||
new ActualClass(
|
||||
"ant",
|
||||
1L,
|
||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
|
||||
0L),
|
||||
new ActualClass(
|
||||
"cat",
|
||||
5L,
|
||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
|
||||
1L),
|
||||
new ActualClass(
|
||||
"dog",
|
||||
4L,
|
||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
|
||||
0L))));
|
||||
assertThat(otherActualClassCount, equalTo(0L));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,19 +19,21 @@
|
|||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.Result;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric.Result> {
|
||||
public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<Result> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
|
@ -39,26 +41,28 @@ public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContent
|
|||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric.Result createTestInstance() {
|
||||
protected Result createTestInstance() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
|
||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
Map<String, Long> row = new TreeMap<>();
|
||||
confusionMatrix.put(classNames.get(i), row);
|
||||
List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
|
||||
for (int j = 0; j < numClasses; j++) {
|
||||
if (randomBoolean()) {
|
||||
row.put(classNames.get(i), randomNonNegativeLong());
|
||||
predictedClasses.add(new PredictedClass(classNames.get(j), randomBoolean() ? randomNonNegativeLong() : null));
|
||||
}
|
||||
actualClasses.add(
|
||||
new ActualClass(
|
||||
classNames.get(i),
|
||||
randomBoolean() ? randomNonNegativeLong() : null,
|
||||
predictedClasses,
|
||||
randomBoolean() ? randomNonNegativeLong() : null));
|
||||
}
|
||||
}
|
||||
long otherClassesCount = randomNonNegativeLong();
|
||||
return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount);
|
||||
return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrixMetric.Result.fromXContent(parser);
|
||||
protected Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -9,7 +9,9 @@ import org.elasticsearch.common.Nullable;
|
|||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
|
@ -26,14 +28,14 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.TreeMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static java.util.Comparator.comparing;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
|
@ -112,18 +114,19 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
|||
.size(size));
|
||||
}
|
||||
if (result == null) { // This is step 2
|
||||
KeyedFilter[] keyedFilters =
|
||||
KeyedFilter[] keyedFiltersActual =
|
||||
topActualClassNames.stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
KeyedFilter[] keyedFiltersPredicted =
|
||||
topActualClassNames.stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
|
||||
.field(actualField),
|
||||
AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(size)
|
||||
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters)
|
||||
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
|
||||
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
|
||||
.otherBucket(true)
|
||||
.otherBucketKey(OTHER_BUCKET_KEY)));
|
||||
}
|
||||
|
@ -134,26 +137,31 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
|||
public void process(Aggregations aggs) {
|
||||
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
||||
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
|
||||
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList());
|
||||
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
|
||||
}
|
||||
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
||||
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
|
||||
Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
|
||||
Map<String, Map<String, Long>> counts = new TreeMap<>();
|
||||
for (Terms.Bucket bucket : termsAgg.getBuckets()) {
|
||||
Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
|
||||
List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
|
||||
for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
|
||||
String actualClass = bucket.getKeyAsString();
|
||||
Map<String, Long> subCounts = new TreeMap<>();
|
||||
counts.put(actualClass, subCounts);
|
||||
long actualClassDocCount = bucket.getDocCount();
|
||||
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
|
||||
List<PredictedClass> predictedClasses = new ArrayList<>();
|
||||
long otherPredictedClassDocCount = 0;
|
||||
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
|
||||
String predictedClass = subBucket.getKeyAsString();
|
||||
Long docCount = subBucket.getDocCount();
|
||||
if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) {
|
||||
subCounts.put(predictedClass, docCount);
|
||||
long docCount = subBucket.getDocCount();
|
||||
if (OTHER_BUCKET_KEY.equals(predictedClass)) {
|
||||
otherPredictedClassDocCount = docCount;
|
||||
} else {
|
||||
predictedClasses.add(new PredictedClass(predictedClass, docCount));
|
||||
}
|
||||
}
|
||||
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
|
||||
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
|
||||
}
|
||||
result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size);
|
||||
result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,37 +199,35 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
|||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
|
||||
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
|
||||
private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
|
||||
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (long) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObject(
|
||||
constructorArg(),
|
||||
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
|
||||
CONFUSION_MATRIX);
|
||||
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
|
||||
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
|
||||
PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
// Immutable
|
||||
private final Map<String, Map<String, Long>> confusionMatrix;
|
||||
private final long otherClassesCount;
|
||||
/** List of actual classes. */
|
||||
private final List<ActualClass> actualClasses;
|
||||
/** Number of actual classes that were not included in the confusion matrix because there were too many of them. */
|
||||
private final long otherActualClassCount;
|
||||
|
||||
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
|
||||
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
|
||||
this.otherClassesCount = otherClassesCount;
|
||||
public Result(List<ActualClass> actualClasses, long otherActualClassCount) {
|
||||
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, CONFUSION_MATRIX));
|
||||
this.otherActualClassCount = requireNonNegative(otherActualClassCount, OTHER_ACTUAL_CLASS_COUNT);
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.confusionMatrix = Collections.unmodifiableMap(
|
||||
in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong)));
|
||||
this.otherClassesCount = in.readLong();
|
||||
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
|
||||
this.otherActualClassCount = in.readVLong();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -234,28 +240,25 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public Map<String, Map<String, Long>> getConfusionMatrix() {
|
||||
return confusionMatrix;
|
||||
public List<ActualClass> getConfusionMatrix() {
|
||||
return actualClasses;
|
||||
}
|
||||
|
||||
public long getOtherClassesCount() {
|
||||
return otherClassesCount;
|
||||
public long getOtherActualClassCount() {
|
||||
return otherActualClassCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeMap(
|
||||
confusionMatrix,
|
||||
StreamOutput::writeString,
|
||||
(out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong));
|
||||
out.writeLong(otherClassesCount);
|
||||
out.writeList(actualClasses);
|
||||
out.writeVLong(otherActualClassCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
|
||||
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
|
||||
builder.field(CONFUSION_MATRIX.getPreferredName(), actualClasses);
|
||||
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -265,13 +268,163 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
|
||||
&& this.otherClassesCount == that.otherClassesCount;
|
||||
return Objects.equals(this.actualClasses, that.actualClasses)
|
||||
&& this.otherActualClassCount == that.otherActualClassCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(confusionMatrix, otherClassesCount);
|
||||
return Objects.hash(actualClasses, otherActualClassCount);
|
||||
}
|
||||
}
|
||||
|
||||
public static class ActualClass implements ToXContentObject, Writeable {
|
||||
|
||||
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
||||
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
||||
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
|
||||
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_actual_class",
|
||||
true,
|
||||
a -> new ActualClass((String) a[0], (long) a[1], (List<PredictedClass>) a[2], (long) a[3]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
||||
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
||||
PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
|
||||
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
|
||||
}
|
||||
|
||||
/** Name of the actual class. */
|
||||
private final String actualClass;
|
||||
/** Number of documents (examples) belonging to the {code actualClass} class. */
|
||||
private final long actualClassDocCount;
|
||||
/** List of predicted classes. */
|
||||
private final List<PredictedClass> predictedClasses;
|
||||
/** Number of documents that were not predicted as any of the {@code predictedClasses}. */
|
||||
private final long otherPredictedClassDocCount;
|
||||
|
||||
public ActualClass(
|
||||
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
|
||||
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
|
||||
this.actualClassDocCount = requireNonNegative(actualClassDocCount, ACTUAL_CLASS_DOC_COUNT);
|
||||
this.predictedClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(predictedClasses, PREDICTED_CLASSES));
|
||||
this.otherPredictedClassDocCount = requireNonNegative(otherPredictedClassDocCount, OTHER_PREDICTED_CLASS_DOC_COUNT);
|
||||
}
|
||||
|
||||
public ActualClass(StreamInput in) throws IOException {
|
||||
this.actualClass = in.readString();
|
||||
this.actualClassDocCount = in.readVLong();
|
||||
this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new));
|
||||
this.otherPredictedClassDocCount = in.readVLong();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualClass);
|
||||
out.writeVLong(actualClassDocCount);
|
||||
out.writeList(predictedClasses);
|
||||
out.writeVLong(otherPredictedClassDocCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
||||
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
||||
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
|
||||
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ActualClass that = (ActualClass) o;
|
||||
return Objects.equals(this.actualClass, that.actualClass)
|
||||
&& this.actualClassDocCount == that.actualClassDocCount
|
||||
&& Objects.equals(this.predictedClasses, that.predictedClasses)
|
||||
&& this.otherPredictedClassDocCount == that.otherPredictedClassDocCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PredictedClass implements ToXContentObject, Writeable {
|
||||
|
||||
private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
|
||||
private static final ParseField COUNT = new ParseField("count");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), PREDICTED_CLASS);
|
||||
PARSER.declareLong(constructorArg(), COUNT);
|
||||
}
|
||||
|
||||
private final String predictedClass;
|
||||
private final long count;
|
||||
|
||||
public PredictedClass(String predictedClass, long count) {
|
||||
this.predictedClass = ExceptionsHelper.requireNonNull(predictedClass, PREDICTED_CLASS);
|
||||
this.count = requireNonNegative(count, COUNT);
|
||||
}
|
||||
|
||||
public PredictedClass(StreamInput in) throws IOException {
|
||||
this.predictedClass = in.readString();
|
||||
this.count = in.readVLong();
|
||||
}
|
||||
|
||||
public String getPredictedClass() {
|
||||
return predictedClass;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(predictedClass);
|
||||
out.writeVLong(count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
|
||||
builder.field(COUNT.getPreferredName(), count);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PredictedClass that = (PredictedClass) o;
|
||||
return Objects.equals(this.predictedClass, that.predictedClass)
|
||||
&& this.count == that.count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(predictedClass, count);
|
||||
}
|
||||
}
|
||||
|
||||
private static long requireNonNegative(long value, ParseField field) {
|
||||
if (value < 0) {
|
||||
throw ExceptionsHelper.serverError("[" + field.getPreferredName() + "] must be >= 0, was: " + value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,50 +5,53 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix.Result> {
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public static MulticlassConfusionMatrix.Result createRandom() {
|
||||
public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<Result> {
|
||||
|
||||
public static Result createRandom() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
|
||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
Map<String, Long> row = new TreeMap<>();
|
||||
confusionMatrix.put(classNames.get(i), row);
|
||||
List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
|
||||
for (int j = 0; j < numClasses; j++) {
|
||||
if (randomBoolean()) {
|
||||
row.put(classNames.get(i), randomNonNegativeLong());
|
||||
predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
|
||||
}
|
||||
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong()));
|
||||
}
|
||||
}
|
||||
long otherClassesCount = randomNonNegativeLong();
|
||||
return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount);
|
||||
return new Result(actualClasses, randomNonNegativeLong());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrix.Result.fromXContent(parser);
|
||||
protected Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrix.Result createTestInstance() {
|
||||
protected Result createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<MulticlassConfusionMatrix.Result> instanceReader() {
|
||||
return MulticlassConfusionMatrix.Result::new;
|
||||
protected Writeable.Reader<Result> instanceReader() {
|
||||
return Result::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -61,4 +64,67 @@ public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTes
|
|||
// allow unknown fields in the root of the object only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
public void testConstructor_ValidationFailures() {
|
||||
{
|
||||
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new Result(null, 0));
|
||||
assertThat(e.getMessage(), equalTo("[confusion_matrix] must not be null."));
|
||||
}
|
||||
{
|
||||
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new Result(Collections.emptyList(), -1));
|
||||
assertThat(e.status().getStatus(), equalTo(500));
|
||||
assertThat(e.getMessage(), equalTo("[other_actual_class_count] must be >= 0, was: -1"));
|
||||
}
|
||||
{
|
||||
IllegalArgumentException e =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new Result(Collections.singletonList(new ActualClass(null, 0, Collections.emptyList(), 0)), 0));
|
||||
assertThat(e.getMessage(), equalTo("[actual_class] must not be null."));
|
||||
}
|
||||
{
|
||||
ElasticsearchException e =
|
||||
expectThrows(
|
||||
ElasticsearchException.class,
|
||||
() -> new Result(Collections.singletonList(new ActualClass("actual_class", -1, Collections.emptyList(), 0)), 0));
|
||||
assertThat(e.status().getStatus(), equalTo(500));
|
||||
assertThat(e.getMessage(), equalTo("[actual_class_doc_count] must be >= 0, was: -1"));
|
||||
}
|
||||
{
|
||||
IllegalArgumentException e =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, null, 0)), 0));
|
||||
assertThat(e.getMessage(), equalTo("[predicted_classes] must not be null."));
|
||||
}
|
||||
{
|
||||
ElasticsearchException e =
|
||||
expectThrows(
|
||||
ElasticsearchException.class,
|
||||
() -> new Result(Collections.singletonList(new ActualClass("actual_class", 0, Collections.emptyList(), -1)), 0));
|
||||
assertThat(e.status().getStatus(), equalTo(500));
|
||||
assertThat(e.getMessage(), equalTo("[other_predicted_class_doc_count] must be >= 0, was: -1"));
|
||||
}
|
||||
{
|
||||
IllegalArgumentException e =
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new Result(
|
||||
Collections.singletonList(
|
||||
new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass(null, 0)), 0)),
|
||||
0));
|
||||
assertThat(e.getMessage(), equalTo("[predicted_class] must not be null."));
|
||||
}
|
||||
{
|
||||
ElasticsearchException e =
|
||||
expectThrows(
|
||||
ElasticsearchException.class,
|
||||
() -> new Result(
|
||||
Collections.singletonList(
|
||||
new ActualClass("actual_class", 0, Collections.singletonList(new PredictedClass("predicted_class", -1)), 0)),
|
||||
0));
|
||||
assertThat(e.status().getStatus(), equalTo(500));
|
||||
assertThat(e.getMessage(), equalTo("[count] must be >= 0, was: -1"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,13 +14,13 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
|||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
|
@ -88,22 +88,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
0L),
|
||||
mockTerms(
|
||||
mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
||||
Arrays.asList(
|
||||
mockTermsBucket(
|
||||
mockFiltersBucket(
|
||||
"dog",
|
||||
30,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockTermsBucket(
|
||||
mockFiltersBucket(
|
||||
"cat",
|
||||
70,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))),
|
||||
0L),
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
|
||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
|
@ -112,15 +113,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 10L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 20L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 30L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 40L);
|
||||
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(result.getOtherClassesCount(), equalTo(0L));
|
||||
assertThat(
|
||||
result.getConfusionMatrix(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0),
|
||||
new ActualClass("cat", 70, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 0))));
|
||||
assertThat(result.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_OtherClassesCountGreaterThanZero() {
|
||||
|
@ -131,22 +130,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockTerms(
|
||||
mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
||||
Arrays.asList(
|
||||
mockTermsBucket(
|
||||
mockFiltersBucket(
|
||||
"dog",
|
||||
30,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockTermsBucket(
|
||||
mockFiltersBucket(
|
||||
"cat",
|
||||
85,
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))),
|
||||
100L),
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
|
||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
|
@ -155,16 +155,13 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 10L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 20L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 30L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 40L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 15L);
|
||||
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(result.getOtherClassesCount(), equalTo(3L));
|
||||
assertThat(
|
||||
result.getConfusionMatrix(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new ActualClass("dog", 30, Arrays.asList(new PredictedClass("cat", 10L), new PredictedClass("dog", 20L)), 0),
|
||||
new ActualClass("cat", 85, Arrays.asList(new PredictedClass("cat", 30L), new PredictedClass("dog", 40L)), 15))));
|
||||
assertThat(result.getOtherActualClassCount(), equalTo(3L));
|
||||
}
|
||||
|
||||
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
|
||||
|
@ -175,9 +172,9 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
return aggregation;
|
||||
}
|
||||
|
||||
private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) {
|
||||
private static Terms.Bucket mockTermsBucket(String key, Aggregations subAggs) {
|
||||
Terms.Bucket bucket = mock(Terms.Bucket.class);
|
||||
when(bucket.getKeyAsString()).thenReturn(actualClass);
|
||||
when(bucket.getKeyAsString()).thenReturn(key);
|
||||
when(bucket.getAggregations()).thenReturn(subAggs);
|
||||
return bucket;
|
||||
}
|
||||
|
@ -189,9 +186,15 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
|||
return aggregation;
|
||||
}
|
||||
|
||||
private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) {
|
||||
private static Filters.Bucket mockFiltersBucket(String key, long docCount, Aggregations subAggs) {
|
||||
Filters.Bucket bucket = mockFiltersBucket(key, docCount);
|
||||
when(bucket.getAggregations()).thenReturn(subAggs);
|
||||
return bucket;
|
||||
}
|
||||
|
||||
private static Filters.Bucket mockFiltersBucket(String key, long docCount) {
|
||||
Filters.Bucket bucket = mock(Filters.Bucket.class);
|
||||
when(bucket.getKeyAsString()).thenReturn(predictedClass);
|
||||
when(bucket.getKeyAsString()).thenReturn(key);
|
||||
when(bucket.getDocCount()).thenReturn(docCount);
|
||||
return bucket;
|
||||
}
|
||||
|
|
|
@ -12,13 +12,13 @@ import org.elasticsearch.action.support.WriteRequest;
|
|||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
|
@ -53,39 +53,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("ant", new HashMap<>());
|
||||
expectedConfusionMatrix.get("ant").put("ant", 1L);
|
||||
expectedConfusionMatrix.get("ant").put("cat", 4L);
|
||||
expectedConfusionMatrix.get("ant").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("ant").put("fox", 2L);
|
||||
expectedConfusionMatrix.get("ant").put("mouse", 5L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("ant", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 5L);
|
||||
expectedConfusionMatrix.get("cat").put("fox", 4L);
|
||||
expectedConfusionMatrix.get("cat").put("mouse", 2L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("ant", 4L);
|
||||
expectedConfusionMatrix.get("dog").put("cat", 2L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("fox", 5L);
|
||||
expectedConfusionMatrix.get("dog").put("mouse", 3L);
|
||||
expectedConfusionMatrix.put("fox", new HashMap<>());
|
||||
expectedConfusionMatrix.get("fox").put("ant", 5L);
|
||||
expectedConfusionMatrix.get("fox").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("fox").put("dog", 2L);
|
||||
expectedConfusionMatrix.get("fox").put("fox", 1L);
|
||||
expectedConfusionMatrix.get("fox").put("mouse", 4L);
|
||||
expectedConfusionMatrix.put("mouse", new HashMap<>());
|
||||
expectedConfusionMatrix.get("mouse").put("ant", 2L);
|
||||
expectedConfusionMatrix.get("mouse").put("cat", 5L);
|
||||
expectedConfusionMatrix.get("mouse").put("dog", 4L);
|
||||
expectedConfusionMatrix.get("mouse").put("fox", 3L);
|
||||
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
|
||||
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
|
||||
assertThat(
|
||||
confusionMatrixResult.getConfusionMatrix(),
|
||||
equalTo(Arrays.asList(
|
||||
new ActualClass("ant",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 1L),
|
||||
new PredictedClass("cat", 4L),
|
||||
new PredictedClass("dog", 3L),
|
||||
new PredictedClass("fox", 2L),
|
||||
new PredictedClass("mouse", 5L)),
|
||||
0),
|
||||
new ActualClass("cat",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 3L),
|
||||
new PredictedClass("cat", 1L),
|
||||
new PredictedClass("dog", 5L),
|
||||
new PredictedClass("fox", 4L),
|
||||
new PredictedClass("mouse", 2L)),
|
||||
0),
|
||||
new ActualClass("dog",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 4L),
|
||||
new PredictedClass("cat", 2L),
|
||||
new PredictedClass("dog", 1L),
|
||||
new PredictedClass("fox", 5L),
|
||||
new PredictedClass("mouse", 3L)),
|
||||
0),
|
||||
new ActualClass("fox",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 5L),
|
||||
new PredictedClass("cat", 3L),
|
||||
new PredictedClass("dog", 2L),
|
||||
new PredictedClass("fox", 1L),
|
||||
new PredictedClass("mouse", 4L)),
|
||||
0),
|
||||
new ActualClass("mouse",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 2L),
|
||||
new PredictedClass("cat", 5L),
|
||||
new PredictedClass("dog", 4L),
|
||||
new PredictedClass("fox", 3L),
|
||||
new PredictedClass("mouse", 1L)),
|
||||
0))));
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
|
||||
|
@ -103,39 +119,55 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("ant", new HashMap<>());
|
||||
expectedConfusionMatrix.get("ant").put("ant", 1L);
|
||||
expectedConfusionMatrix.get("ant").put("cat", 4L);
|
||||
expectedConfusionMatrix.get("ant").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("ant").put("fox", 2L);
|
||||
expectedConfusionMatrix.get("ant").put("mouse", 5L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("ant", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 5L);
|
||||
expectedConfusionMatrix.get("cat").put("fox", 4L);
|
||||
expectedConfusionMatrix.get("cat").put("mouse", 2L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("ant", 4L);
|
||||
expectedConfusionMatrix.get("dog").put("cat", 2L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("fox", 5L);
|
||||
expectedConfusionMatrix.get("dog").put("mouse", 3L);
|
||||
expectedConfusionMatrix.put("fox", new HashMap<>());
|
||||
expectedConfusionMatrix.get("fox").put("ant", 5L);
|
||||
expectedConfusionMatrix.get("fox").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("fox").put("dog", 2L);
|
||||
expectedConfusionMatrix.get("fox").put("fox", 1L);
|
||||
expectedConfusionMatrix.get("fox").put("mouse", 4L);
|
||||
expectedConfusionMatrix.put("mouse", new HashMap<>());
|
||||
expectedConfusionMatrix.get("mouse").put("ant", 2L);
|
||||
expectedConfusionMatrix.get("mouse").put("cat", 5L);
|
||||
expectedConfusionMatrix.get("mouse").put("dog", 4L);
|
||||
expectedConfusionMatrix.get("mouse").put("fox", 3L);
|
||||
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
|
||||
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
|
||||
assertThat(
|
||||
confusionMatrixResult.getConfusionMatrix(),
|
||||
equalTo(Arrays.asList(
|
||||
new ActualClass("ant",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 1L),
|
||||
new PredictedClass("cat", 4L),
|
||||
new PredictedClass("dog", 3L),
|
||||
new PredictedClass("fox", 2L),
|
||||
new PredictedClass("mouse", 5L)),
|
||||
0),
|
||||
new ActualClass("cat",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 3L),
|
||||
new PredictedClass("cat", 1L),
|
||||
new PredictedClass("dog", 5L),
|
||||
new PredictedClass("fox", 4L),
|
||||
new PredictedClass("mouse", 2L)),
|
||||
0),
|
||||
new ActualClass("dog",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 4L),
|
||||
new PredictedClass("cat", 2L),
|
||||
new PredictedClass("dog", 1L),
|
||||
new PredictedClass("fox", 5L),
|
||||
new PredictedClass("mouse", 3L)),
|
||||
0),
|
||||
new ActualClass("fox",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 5L),
|
||||
new PredictedClass("cat", 3L),
|
||||
new PredictedClass("dog", 2L),
|
||||
new PredictedClass("fox", 1L),
|
||||
new PredictedClass("mouse", 4L)),
|
||||
0),
|
||||
new ActualClass("mouse",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 2L),
|
||||
new PredictedClass("cat", 5L),
|
||||
new PredictedClass("dog", 4L),
|
||||
new PredictedClass("fox", 3L),
|
||||
new PredictedClass("mouse", 1L)),
|
||||
0))));
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
|
||||
|
@ -153,24 +185,22 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
|||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("ant", new HashMap<>());
|
||||
expectedConfusionMatrix.get("ant").put("ant", 1L);
|
||||
expectedConfusionMatrix.get("ant").put("cat", 4L);
|
||||
expectedConfusionMatrix.get("ant").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("ant").put("_other_", 7L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("ant", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 5L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 6L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("ant", 4L);
|
||||
expectedConfusionMatrix.get("dog").put("cat", 2L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("_other_", 8L);
|
||||
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L));
|
||||
assertThat(
|
||||
confusionMatrixResult.getConfusionMatrix(),
|
||||
equalTo(Arrays.asList(
|
||||
new ActualClass("ant",
|
||||
15,
|
||||
Arrays.asList(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)),
|
||||
7),
|
||||
new ActualClass("cat",
|
||||
15,
|
||||
Arrays.asList(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)),
|
||||
6),
|
||||
new ActualClass("dog",
|
||||
15,
|
||||
Arrays.asList(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)),
|
||||
8))));
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
|
||||
}
|
||||
|
||||
private static void indexAnimalsData(String indexName) {
|
||||
|
|
|
@ -618,8 +618,40 @@ setup:
|
|||
}
|
||||
}
|
||||
|
||||
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
|
||||
- match: { classification.multiclass_confusion_matrix._other_: 0 }
|
||||
- match:
|
||||
classification.multiclass_confusion_matrix:
|
||||
confusion_matrix:
|
||||
- actual_class: "cat"
|
||||
actual_class_doc_count: 3
|
||||
predicted_classes:
|
||||
- predicted_class: "cat"
|
||||
count: 2
|
||||
- predicted_class: "dog"
|
||||
count: 1
|
||||
- predicted_class: "mouse"
|
||||
count: 0
|
||||
other_predicted_class_doc_count: 0
|
||||
- actual_class: "dog"
|
||||
actual_class_doc_count: 3
|
||||
predicted_classes:
|
||||
- predicted_class: "cat"
|
||||
count: 1
|
||||
- predicted_class: "dog"
|
||||
count: 2
|
||||
- predicted_class: "mouse"
|
||||
count: 0
|
||||
other_predicted_class_doc_count: 0
|
||||
- actual_class: "mouse"
|
||||
actual_class_doc_count: 2
|
||||
predicted_classes:
|
||||
- predicted_class: "cat"
|
||||
count: 1
|
||||
- predicted_class: "dog"
|
||||
count: 0
|
||||
- predicted_class: "mouse"
|
||||
count: 1
|
||||
other_predicted_class_doc_count: 0
|
||||
other_actual_class_count: 0
|
||||
---
|
||||
"Test classification multiclass_confusion_matrix with explicit size":
|
||||
- do:
|
||||
|
@ -636,8 +668,26 @@ setup:
|
|||
}
|
||||
}
|
||||
|
||||
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } }
|
||||
- match: { classification.multiclass_confusion_matrix._other_: 1 }
|
||||
- match:
|
||||
classification.multiclass_confusion_matrix:
|
||||
confusion_matrix:
|
||||
- actual_class: "cat"
|
||||
actual_class_doc_count: 3
|
||||
predicted_classes:
|
||||
- predicted_class: "cat"
|
||||
count: 2
|
||||
- predicted_class: "dog"
|
||||
count: 1
|
||||
other_predicted_class_doc_count: 0
|
||||
- actual_class: "dog"
|
||||
actual_class_doc_count: 3
|
||||
predicted_classes:
|
||||
- predicted_class: "cat"
|
||||
count: 1
|
||||
- predicted_class: "dog"
|
||||
count: 2
|
||||
other_predicted_class_doc_count: 0
|
||||
other_actual_class_count: 1
|
||||
---
|
||||
"Test classification with null metrics":
|
||||
- do:
|
||||
|
@ -653,8 +703,7 @@ setup:
|
|||
}
|
||||
}
|
||||
|
||||
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
|
||||
- match: { classification.multiclass_confusion_matrix._other_: 0 }
|
||||
- is_true: classification.multiclass_confusion_matrix
|
||||
---
|
||||
"Test classification given missing actual_field":
|
||||
- do:
|
||||
|
|
Loading…
Reference in New Issue