[7.x] Implement evaluation API for multiclass classification problem (#47126) (#47343)

This commit is contained in:
Przemysław Witek 2019-10-04 17:54:51 +02:00 committed by GitHub
parent e3aab1295e
commit ee952da2e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1963 additions and 41 deletions

View File

@ -18,7 +18,9 @@
*/
package org.elasticsearch.client.ml.dataframe.evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@ -41,6 +43,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluations
new NamedXContentRegistry.Entry(
Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
// Evaluation metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
@ -48,6 +51,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(MulticlassConfusionMatrixMetric.NAME),
MulticlassConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
new NamedXContentRegistry.Entry(
@ -60,10 +67,14 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent),
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(MulticlassConfusionMatrixMetric.NAME),
MulticlassConfusionMatrixMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
}
}

View File

@ -0,0 +1,132 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
/**
* Evaluation of classification results.
*/
public class Classification implements Evaluation {
public static final String NAME = "classification";
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
private static final ParseField METRICS = new ParseField("metrics");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
}
public static Classification fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/**
* The field containing the actual value
* The value of this field is assumed to be numeric
*/
private final String actualField;
/**
* The field containing the predicted value
* The value of this field is assumed to be numeric
*/
private final String predictedField;
/**
* The list of metrics to calculate
*/
private final List<EvaluationMetric> metrics;
public Classification(String actualField, String predictedField) {
this(actualField, predictedField, (List<EvaluationMetric>)null);
}
public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
this(actualField, predictedField, Arrays.asList(metrics));
}
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
this.actualField = Objects.requireNonNull(actualField);
this.predictedField = Objects.requireNonNull(predictedField);
if (metrics != null) {
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
}
this.metrics = metrics;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
if (metrics != null) {
builder.startObject(METRICS.getPreferredName());
for (EvaluationMetric metric : metrics) {
builder.field(metric.getName(), metric);
}
builder.endObject();
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Classification that = (Classification) o;
return Objects.equals(that.actualField, this.actualField)
&& Objects.equals(that.predictedField, this.predictedField)
&& Objects.equals(that.metrics, this.metrics);
}
@Override
public int hashCode() {
return Objects.hash(actualField, predictedField, metrics);
}
}

View File

@ -0,0 +1,164 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
* Calculates the multiclass confusion matrix.
*/
public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
public static final String NAME = "multiclass_confusion_matrix";
public static final ParseField SIZE = new ParseField("size");
private static final ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> PARSER = createParser();
private static ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> createParser() {
ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> parser =
new ConstructingObjectParser<>(NAME, true, args -> new MulticlassConfusionMatrixMetric((Integer) args[0]));
parser.declareInt(optionalConstructorArg(), SIZE);
return parser;
}
public static MulticlassConfusionMatrixMetric fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final Integer size;
public MulticlassConfusionMatrixMetric() {
this(null);
}
public MulticlassConfusionMatrixMetric(@Nullable Integer size) {
this.size = size;
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (size != null) {
builder.field(SIZE.getPreferredName(), size);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MulticlassConfusionMatrixMetric that = (MulticlassConfusionMatrixMetric) o;
return Objects.equals(this.size, that.size);
}
@Override
public int hashCode() {
return Objects.hash(size);
}
public static class Result implements EvaluationMetric.Result {
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
static {
PARSER.declareObject(
constructorArg(),
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
// Immutable
private final Map<String, Map<String, Long>> confusionMatrix;
private final long otherClassesCount;
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
this.otherClassesCount = otherClassesCount;
}
@Override
public String getMetricName() {
return NAME;
}
public Map<String, Map<String, Long>> getConfusionMatrix() {
return confusionMatrix;
}
public long getOtherClassesCount() {
return otherClassesCount;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
&& this.otherClassesCount == that.otherClassesCount;
}
@Override
public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount);
}
}
}

View File

@ -125,7 +125,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -1638,19 +1640,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
String indexName = "evaluate-test-index";
createIndex(indexName, mappingForClassification());
createIndex(indexName, mappingForSoftClassification());
BulkRequest bulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "blue", false, 0.1)) // #0
.add(docForClassification(indexName, "blue", false, 0.2)) // #1
.add(docForClassification(indexName, "blue", false, 0.3)) // #2
.add(docForClassification(indexName, "blue", false, 0.4)) // #3
.add(docForClassification(indexName, "blue", false, 0.7)) // #4
.add(docForClassification(indexName, "blue", true, 0.2)) // #5
.add(docForClassification(indexName, "green", true, 0.3)) // #6
.add(docForClassification(indexName, "green", true, 0.4)) // #7
.add(docForClassification(indexName, "green", true, 0.8)) // #8
.add(docForClassification(indexName, "green", true, 0.9)); // #9
.add(docForSoftClassification(indexName, "blue", false, 0.1)) // #0
.add(docForSoftClassification(indexName, "blue", false, 0.2)) // #1
.add(docForSoftClassification(indexName, "blue", false, 0.3)) // #2
.add(docForSoftClassification(indexName, "blue", false, 0.4)) // #3
.add(docForSoftClassification(indexName, "blue", false, 0.7)) // #4
.add(docForSoftClassification(indexName, "blue", true, 0.2)) // #5
.add(docForSoftClassification(indexName, "green", true, 0.3)) // #6
.add(docForSoftClassification(indexName, "green", true, 0.4)) // #7
.add(docForSoftClassification(indexName, "green", true, 0.8)) // #8
.add(docForSoftClassification(indexName, "green", true, 0.9)); // #9
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@ -1712,19 +1714,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
String indexName = "evaluate-with-query-test-index";
createIndex(indexName, mappingForClassification());
createIndex(indexName, mappingForSoftClassification());
BulkRequest bulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "blue", true, 1.0)) // #0
.add(docForClassification(indexName, "blue", true, 1.0)) // #1
.add(docForClassification(indexName, "blue", true, 1.0)) // #2
.add(docForClassification(indexName, "blue", true, 1.0)) // #3
.add(docForClassification(indexName, "blue", true, 0.0)) // #4
.add(docForClassification(indexName, "blue", true, 0.0)) // #5
.add(docForClassification(indexName, "green", true, 0.0)) // #6
.add(docForClassification(indexName, "green", true, 0.0)) // #7
.add(docForClassification(indexName, "green", true, 0.0)) // #8
.add(docForClassification(indexName, "green", true, 1.0)); // #9
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #0
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #1
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #2
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #3
.add(docForSoftClassification(indexName, "blue", true, 0.0)) // #4
.add(docForSoftClassification(indexName, "blue", true, 0.0)) // #5
.add(docForSoftClassification(indexName, "green", true, 0.0)) // #6
.add(docForSoftClassification(indexName, "green", true, 0.0)) // #7
.add(docForSoftClassification(indexName, "green", true, 0.0)) // #8
.add(docForSoftClassification(indexName, "green", true, 1.0)); // #9
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@ -1787,6 +1789,85 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
}
public void testEvaluateDataFrame_Classification() throws IOException {
String indexName = "evaluate-classification-test-index";
createIndex(indexName, mappingForClassification());
BulkRequest regressionBulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "cat", "cat"))
.add(docForClassification(indexName, "cat", "cat"))
.add(docForClassification(indexName, "cat", "cat"))
.add(docForClassification(indexName, "cat", "dog"))
.add(docForClassification(indexName, "cat", "fish"))
.add(docForClassification(indexName, "dog", "cat"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "dog", "dog"))
.add(docForClassification(indexName, "horse", "cat"));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
indexName,
null,
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 3L);
expectedConfusionMatrix.get("cat").put("dog", 1L);
expectedConfusionMatrix.get("cat").put("horse", 0L);
expectedConfusionMatrix.get("cat").put("_other_", 1L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 1L);
expectedConfusionMatrix.get("dog").put("dog", 3L);
expectedConfusionMatrix.get("dog").put("horse", 0L);
expectedConfusionMatrix.put("horse", new HashMap<>());
expectedConfusionMatrix.get("horse").put("cat", 1L);
expectedConfusionMatrix.get("horse").put("dog", 0L);
expectedConfusionMatrix.get("horse").put("horse", 0L);
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
}
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
indexName,
null,
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
MulticlassConfusionMatrixMetric.Result mcmResult =
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 3L);
expectedConfusionMatrix.get("cat").put("dog", 1L);
expectedConfusionMatrix.get("cat").put("_other_", 1L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 1L);
expectedConfusionMatrix.get("dog").put("dog", 3L);
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
}
}
private static XContentBuilder defaultMappingForTest() throws IOException {
return XContentFactory.jsonBuilder().startObject()
.startObject("properties")
@ -1804,7 +1885,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
private static final String actualField = "label";
private static final String probabilityField = "p";
private static XContentBuilder mappingForClassification() throws IOException {
private static XContentBuilder mappingForSoftClassification() throws IOException {
return XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject(datasetField)
@ -1820,26 +1901,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.endObject();
}
private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
private static IndexRequest docForSoftClassification(String indexName, String dataset, boolean isTrue, double p) {
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
}
private static final String actualClassField = "actual_class";
private static final String predictedClassField = "predicted_class";
private static XContentBuilder mappingForClassification() throws IOException {
return XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject(actualClassField)
.field("type", "keyword")
.endObject()
.startObject(predictedClassField)
.field("type", "keyword")
.endObject()
.endObject()
.endObject();
}
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
return new IndexRequest()
.index(indexName)
.source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
}
private static final String actualRegression = "regression_actual";
private static final String probabilityRegression = "regression_prob";
private static XContentBuilder mappingForRegression() throws IOException {
return XContentFactory.jsonBuilder().startObject()
.startObject("properties")
.startObject(actualRegression)
.field("type", "double")
.startObject(actualRegression)
.field("type", "double")
.endObject()
.startObject(probabilityRegression)
.field("type", "double")
.endObject()
.endObject()
.startObject(probabilityRegression)
.field("type", "double")
.endObject()
.endObject()
.endObject();
.endObject();
}
private static IndexRequest docForRegression(String indexName, double act, double p) {
@ -1854,11 +1957,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
public void testEstimateMemoryUsage() throws IOException {
String indexName = "estimate-test-index";
createIndex(indexName, mappingForClassification());
createIndex(indexName, mappingForSoftClassification());
BulkRequest bulk1 = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 10; ++i) {
bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
bulk1.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
}
highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
@ -1884,7 +1987,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
BulkRequest bulk2 = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 10; i < 100; ++i) {
bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
bulk2.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
}
highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);

View File

@ -57,7 +57,9 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction;
import org.elasticsearch.client.indexlifecycle.UnfollowAction;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -684,7 +686,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(45, namedXContents.size());
assertEquals(48, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -724,22 +726,24 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertThat(names,
hasItems(AucRocMetric.NAME,
PrecisionMetric.NAME,
RecallMetric.NAME,
ConfusionMatrixMetric.NAME,
MulticlassConfusionMatrixMetric.NAME,
MeanSquaredErrorMetric.NAME,
RSquaredMetric.NAME));
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names,
hasItems(AucRocMetric.NAME,
PrecisionMetric.NAME,
RecallMetric.NAME,
ConfusionMatrixMetric.NAME,
MulticlassConfusionMatrixMetric.NAME,
MeanSquaredErrorMetric.NAME,
RSquaredMetric.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));

View File

@ -0,0 +1,64 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.function.Predicate;
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
public static Classification createRandom() {
return new Classification(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomBoolean() ? null : Arrays.asList(new MulticlassConfusionMatrixMetric()));
}
@Override
protected Classification createTestInstance() {
return createRandom();
}
@Override
protected Classification doParseInstance(XContentParser parser) throws IOException {
return Classification.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
// allow unknown fields in the root of the object only
return field -> !field.isEmpty();
}
}

View File

@ -0,0 +1,74 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric.Result> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected MulticlassConfusionMatrixMetric.Result createTestInstance() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
for (int i = 0; i < numClasses; i++) {
Map<String, Long> row = new TreeMap<>();
confusionMatrix.put(classNames.get(i), row);
for (int j = 0; j < numClasses; j++) {
if (randomBoolean()) {
row.put(classNames.get(i), randomNonNegativeLong());
}
}
}
long otherClassesCount = randomNonNegativeLong();
return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount);
}
@Override
protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrixMetric.Result.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
// allow unknown fields in the root of the object only
return field -> !field.isEmpty();
}
}

View File

@ -0,0 +1,50 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
@Override
protected MulticlassConfusionMatrixMetric createTestInstance() {
Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null;
return new MulticlassConfusionMatrixMetric(size);
}
@Override
protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrixMetric.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -139,6 +139,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
@ -469,6 +471,14 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
// ML - Data frame evaluation
new NamedWriteableRegistry.Entry(
Evaluation.class,
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification.NAME.getPreferredName(),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification::new),
new NamedWriteableRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix.Result::new),
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new),
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new),

View File

@ -8,7 +8,10 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
@ -32,6 +35,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluations
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
BinarySoftClassification::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
// Soft classification metrics
@ -41,6 +45,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
ConfusionMatrix::fromXContent));
// Classification metrics
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
MulticlassConfusionMatrix::fromXContent));
// Regression metrics
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
@ -54,6 +62,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluations
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(),
Classification::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
// Evaluation Metrics
@ -65,6 +75,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
Recall::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError::new));
@ -79,6 +92,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
ScoreByThresholdResult::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError.Result::new));

View File

@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
/**
* Evaluation of classification results.
*/
public class Classification implements Evaluation {
public static final ParseField NAME = new ParseField("classification");
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
private static final ParseField METRICS = new ParseField("metrics");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<ClassificationMetric>) a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS);
}
public static Classification fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/**
* The field containing the actual value
* The value of this field is assumed to be numeric
*/
private final String actualField;
/**
* The field containing the predicted value
* The value of this field is assumed to be numeric
*/
private final String predictedField;
/**
* The list of metrics to calculate
*/
private final List<ClassificationMetric> metrics;
public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.metrics = initMetrics(metrics);
}
public Classification(StreamInput in) throws IOException {
this.actualField = in.readString();
this.predictedField = in.readString();
this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
}
private static List<ClassificationMetric> initMetrics(@Nullable List<ClassificationMetric> parsedMetrics) {
List<ClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
if (metrics.isEmpty()) {
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
}
Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName));
return metrics;
}
private static List<ClassificationMetric> defaultMetrics() {
return Arrays.asList(new MulticlassConfusionMatrix());
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public List<ClassificationMetric> getMetrics() {
return metrics;
}
@Override
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
SearchSourceBuilder searchSourceBuilder =
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
for (ClassificationMetric metric : metrics) {
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
aggs.forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
}
@Override
public void process(SearchResponse searchResponse) {
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
if (searchResponse.getHits().getTotalHits().value == 0) {
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
}
for (ClassificationMetric metric : metrics) {
metric.process(searchResponse.getAggregations());
}
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(actualField);
out.writeString(predictedField);
out.writeNamedWriteableList(metrics);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
builder.startObject(METRICS.getPreferredName());
for (ClassificationMetric metric : metrics) {
builder.field(metric.getWriteableName(), metric);
}
builder.endObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Classification that = (Classification) o;
return Objects.equals(that.actualField, this.actualField)
&& Objects.equals(that.predictedField, this.predictedField)
&& Objects.equals(that.metrics, this.metrics);
}
@Override
public int hashCode() {
return Objects.hash(actualField, predictedField, metrics);
}
}

View File

@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import java.util.List;
public interface ClassificationMetric extends EvaluationMetric {
/**
* Builds the aggregation that collect required data to compute the metric
* @param actualField the field that stores the actual value
* @param predictedField the field that stores the predicted value
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, String predictedField);
/**
* Processes given aggregations as a step towards computing result
* @param aggs aggregations from {@link SearchResponse}
*/
void process(Aggregations aggs);
}

View File

@ -0,0 +1,277 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
* {@link MulticlassConfusionMatrix} is a metric that answers the question:
* "How many examples belonging to class X were classified as Y by the classifier?"
* for all the possible class pairs {X, Y}.
*/
public class MulticlassConfusionMatrix implements ClassificationMetric {
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
public static final ParseField SIZE = new ParseField("size");
private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
parser.declareInt(optionalConstructorArg(), SIZE);
return parser;
}
public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
private static final String OTHER_BUCKET_KEY = "_other_";
private static final int DEFAULT_SIZE = 10;
private static final int MAX_SIZE = 1000;
private final int size;
private List<String> topActualClassNames;
private Result result;
public MulticlassConfusionMatrix() {
this((Integer) null);
}
public MulticlassConfusionMatrix(@Nullable Integer size) {
if (size != null && (size <= 0 || size > MAX_SIZE)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
}
this.size = size != null ? size : DEFAULT_SIZE;
}
public MulticlassConfusionMatrix(StreamInput in) throws IOException {
this.size = in.readVInt();
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public String getName() {
return NAME.getPreferredName();
}
public int getSize() {
return size;
}
@Override
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
if (topActualClassNames == null) { // This is step 1
return Arrays.asList(
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size));
}
if (result == null) { // This is step 2
KeyedFilter[] keyedFilters =
topActualClassNames.stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
return Arrays.asList(
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
.field(actualField),
AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY)));
}
return Collections.emptyList();
}
@Override
public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList());
}
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
Map<String, Map<String, Long>> counts = new TreeMap<>();
for (Terms.Bucket bucket : termsAgg.getBuckets()) {
String actualClass = bucket.getKeyAsString();
Map<String, Long> subCounts = new TreeMap<>();
counts.put(actualClass, subCounts);
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
String predictedClass = subBucket.getKeyAsString();
Long docCount = subBucket.getDocCount();
if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) {
subCounts.put(predictedClass, docCount);
}
}
}
result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size);
}
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(size);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(SIZE.getPreferredName(), size);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
return Objects.equals(this.size, that.size);
}
@Override
public int hashCode() {
return Objects.hash(size);
}
public static class Result implements EvaluationMetricResult {
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>(
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
static {
PARSER.declareObject(
constructorArg(),
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
CONFUSION_MATRIX);
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
// Immutable
private final Map<String, Map<String, Long>> confusionMatrix;
private final long otherClassesCount;
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
this.otherClassesCount = otherClassesCount;
}
public Result(StreamInput in) throws IOException {
this.confusionMatrix = Collections.unmodifiableMap(
in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong)));
this.otherClassesCount = in.readLong();
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
public Map<String, Map<String, Long>> getConfusionMatrix() {
return confusionMatrix;
}
public long getOtherClassesCount() {
return otherClassesCount;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeMap(
confusionMatrix,
StreamOutput::writeString,
(out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong));
out.writeLong(otherClassesCount);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
&& this.otherClassesCount == that.otherClassesCount;
}
@Override
public int hashCode() {
return Objects.hash(confusionMatrix, otherClassesCount);
}
}
}

View File

@ -11,6 +11,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
@ -29,6 +30,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
String evaluationName = randomAlphaOfLength(10);
List<EvaluationMetricResult> metrics =
Arrays.asList(
MulticlassConfusionMatrixResultTests.createRandom(),
new MeanSquaredError.Result(randomDouble()),
new RSquared.Result(randomDouble()));
int numMetrics = randomIntBetween(0, metrics.size());

View File

@ -0,0 +1,222 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
}
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
public static Classification createRandom() {
return new Classification(
randomAlphaOfLength(10),
randomAlphaOfLength(10),
randomBoolean() ? null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom()));
}
@Override
protected Classification doParseInstance(XContentParser parser) throws IOException {
return Classification.fromXContent(parser);
}
@Override
protected Classification createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Classification> instanceReader() {
return Classification::new;
}
public void testConstructor_GivenEmptyMetrics() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", "bar", Collections.emptyList()));
assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics"));
}
public void testBuildSearch() {
QueryBuilder userProvidedQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value"));
QueryBuilder expectedSearchQuery =
QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery("act"))
.filter(QueryBuilders.existsQuery("pred"))
.filter(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("field_A", "some-value"))
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
}
public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() {
ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
assertThat(metric1.getResult(), isEmpty());
assertThat(metric2.getResult(), isEmpty());
assertThat(metric3.getResult(), isEmpty());
assertThat(metric4.getResult(), isEmpty());
assertThat(evaluation.hasAllResults(), is(false));
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
assertThat(metric1.getResult(), isEmpty());
assertThat(metric2.getResult(), isEmpty());
assertThat(metric3.getResult(), isEmpty());
assertThat(metric4.getResult(), isEmpty());
assertThat(evaluation.hasAllResults(), is(false));
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
assertThat(metric1.getResult(), isPresent());
assertThat(metric2.getResult(), isEmpty());
assertThat(metric3.getResult(), isEmpty());
assertThat(metric4.getResult(), isEmpty());
assertThat(evaluation.hasAllResults(), is(false));
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
assertThat(metric1.getResult(), isPresent());
assertThat(metric2.getResult(), isPresent());
assertThat(metric3.getResult(), isEmpty());
assertThat(metric4.getResult(), isEmpty());
assertThat(evaluation.hasAllResults(), is(false));
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
assertThat(metric1.getResult(), isPresent());
assertThat(metric2.getResult(), isPresent());
assertThat(metric3.getResult(), isPresent());
assertThat(metric4.getResult(), isEmpty());
assertThat(evaluation.hasAllResults(), is(false));
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
assertThat(metric1.getResult(), isPresent());
assertThat(metric2.getResult(), isPresent());
assertThat(metric3.getResult(), isPresent());
assertThat(metric4.getResult(), isPresent());
assertThat(evaluation.hasAllResults(), is(true));
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
assertThat(metric1.getResult(), isPresent());
assertThat(metric2.getResult(), isPresent());
assertThat(metric3.getResult(), isPresent());
assertThat(metric4.getResult(), isPresent());
assertThat(evaluation.hasAllResults(), is(true));
}
private static SearchResponse mockSearchResponseWithNonZeroTotalHits() {
SearchResponse searchResponse = mock(SearchResponse.class);
SearchHits hits = new SearchHits(SearchHits.EMPTY, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0);
when(searchResponse.getHits()).thenReturn(hits);
return searchResponse;
}
/**
* Metric which iterates through its steps in {@link #process} method.
* Number of steps is configurable.
* Upon reaching the last step, the result is produced.
*/
private static class FakeClassificationMetric implements ClassificationMetric {
private final String name;
private final int numSteps;
private int currentStepIndex;
private EvaluationMetricResult result;
FakeClassificationMetric(String name, int numSteps) {
this.name = name;
this.numSteps = numSteps;
}
@Override
public String getName() {
return name;
}
@Override
public String getWriteableName() {
return name;
}
@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
return Collections.emptyList();
}
@Override
public void process(Aggregations aggs) {
if (result != null) {
return;
}
currentStepIndex++;
if (currentStepIndex == numSteps) {
// This is the last step, time to write evaluation result
result = mock(EvaluationMetricResult.class);
}
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) {
return builder;
}
@Override
public void writeTo(StreamOutput out) {
}
}
}

View File

@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix.Result> {
public static MulticlassConfusionMatrix.Result createRandom() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
for (int i = 0; i < numClasses; i++) {
Map<String, Long> row = new TreeMap<>();
confusionMatrix.put(classNames.get(i), row);
for (int j = 0; j < numClasses; j++) {
if (randomBoolean()) {
row.put(classNames.get(i), randomNonNegativeLong());
}
}
}
long otherClassesCount = randomNonNegativeLong();
return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount);
}
@Override
protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrix.Result.fromXContent(parser);
}
@Override
protected MulticlassConfusionMatrix.Result createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<MulticlassConfusionMatrix.Result> instanceReader() {
return MulticlassConfusionMatrix.Result::new;
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
// allow unknown fields in the root of the object only
return field -> !field.isEmpty();
}
}

View File

@ -0,0 +1,205 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
@Override
protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
return MulticlassConfusionMatrix.fromXContent(parser);
}
@Override
protected MulticlassConfusionMatrix createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<MulticlassConfusionMatrix> instanceReader() {
return MulticlassConfusionMatrix::new;
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
public static MulticlassConfusionMatrix createRandom() {
Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
return new MulticlassConfusionMatrix(size);
}
public void testConstructor_SizeValidationFailures() {
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
}
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
}
{
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001));
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
}
}
public void testAggs() {
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
assertThat(aggs, is(not(empty())));
assertThat(confusionMatrix.getResult(), equalTo(Optional.empty()));
}
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class",
Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
0L),
mockTerms(
"multiclass_confusion_matrix_step_2_by_actual_class",
Arrays.asList(
mockTermsBucket(
"dog",
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockTermsBucket(
"cat",
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))),
0L),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 10L);
expectedConfusionMatrix.get("dog").put("dog", 20L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 30L);
expectedConfusionMatrix.get("cat").put("dog", 40L);
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(result.getOtherClassesCount(), equalTo(0L));
}
public void testEvaluate_OtherClassesCountGreaterThanZero() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(
"multiclass_confusion_matrix_step_1_by_actual_class",
Arrays.asList(
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
100L),
mockTerms(
"multiclass_confusion_matrix_step_2_by_actual_class",
Arrays.asList(
mockTermsBucket(
"dog",
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
mockTermsBucket(
"cat",
new Aggregations(Arrays.asList(mockFilters(
"multiclass_confusion_matrix_step_2_by_predicted_class",
Arrays.asList(
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))),
100L),
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("cat", 10L);
expectedConfusionMatrix.get("dog").put("dog", 20L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("cat", 30L);
expectedConfusionMatrix.get("cat").put("dog", 40L);
expectedConfusionMatrix.get("cat").put("_other_", 15L);
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(result.getOtherClassesCount(), equalTo(3L));
}
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
Terms aggregation = mock(Terms.class);
when(aggregation.getName()).thenReturn(name);
doReturn(buckets).when(aggregation).getBuckets();
when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
return aggregation;
}
private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) {
Terms.Bucket bucket = mock(Terms.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(actualClass);
when(bucket.getAggregations()).thenReturn(subAggs);
return bucket;
}
private static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
Filters aggregation = mock(Filters.class);
when(aggregation.getName()).thenReturn(name);
doReturn(buckets).when(aggregation).getBuckets();
return aggregation;
}
private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) {
Filters.Bucket bucket = mock(Filters.Bucket.class);
when(bucket.getKeyAsString()).thenReturn(predictedClass);
when(bucket.getDocCount()).thenReturn(docCount);
return bucket;
}
private static Cardinality mockCardinality(String name, long value) {
Cardinality aggregation = mock(Cardinality.class);
when(aggregation.getName()).thenReturn(name);
when(aggregation.getValue()).thenReturn(value);
return aggregation;
}
}

View File

@ -100,6 +100,9 @@ integTest.runner {
'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds',
'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
'ml/evaluate_data_frame/Test classification given evaluation with empty metrics',
'ml/evaluate_data_frame/Test classification given missing actual_field',
'ml/evaluate_data_frame/Test classification given missing predicted_field',
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
'ml/evaluate_data_frame/Test regression given missing actual_field',
'ml/evaluate_data_frame/Test regression given missing predicted_field',

View File

@ -0,0 +1,200 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.junit.After;
import org.junit.Before;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
@Before
public void setup() {
indexAnimalsData(ANIMALS_DATA_INDEX);
}
@After
public void cleanup() {
cleanUp();
}
public void testEvaluate_MulticlassClassification_DefaultMetrics() {
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
new EvaluateDataFrameAction.Request()
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
.setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("ant", new HashMap<>());
expectedConfusionMatrix.get("ant").put("ant", 1L);
expectedConfusionMatrix.get("ant").put("cat", 4L);
expectedConfusionMatrix.get("ant").put("dog", 3L);
expectedConfusionMatrix.get("ant").put("fox", 2L);
expectedConfusionMatrix.get("ant").put("mouse", 5L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("ant", 3L);
expectedConfusionMatrix.get("cat").put("cat", 1L);
expectedConfusionMatrix.get("cat").put("dog", 5L);
expectedConfusionMatrix.get("cat").put("fox", 4L);
expectedConfusionMatrix.get("cat").put("mouse", 2L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("ant", 4L);
expectedConfusionMatrix.get("dog").put("cat", 2L);
expectedConfusionMatrix.get("dog").put("dog", 1L);
expectedConfusionMatrix.get("dog").put("fox", 5L);
expectedConfusionMatrix.get("dog").put("mouse", 3L);
expectedConfusionMatrix.put("fox", new HashMap<>());
expectedConfusionMatrix.get("fox").put("ant", 5L);
expectedConfusionMatrix.get("fox").put("cat", 3L);
expectedConfusionMatrix.get("fox").put("dog", 2L);
expectedConfusionMatrix.get("fox").put("fox", 1L);
expectedConfusionMatrix.get("fox").put("mouse", 4L);
expectedConfusionMatrix.put("mouse", new HashMap<>());
expectedConfusionMatrix.get("mouse").put("ant", 2L);
expectedConfusionMatrix.get("mouse").put("cat", 5L);
expectedConfusionMatrix.get("mouse").put("dog", 4L);
expectedConfusionMatrix.get("mouse").put("fox", 3L);
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
}
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
new EvaluateDataFrameAction.Request()
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
.setEvaluation(
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("ant", new HashMap<>());
expectedConfusionMatrix.get("ant").put("ant", 1L);
expectedConfusionMatrix.get("ant").put("cat", 4L);
expectedConfusionMatrix.get("ant").put("dog", 3L);
expectedConfusionMatrix.get("ant").put("fox", 2L);
expectedConfusionMatrix.get("ant").put("mouse", 5L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("ant", 3L);
expectedConfusionMatrix.get("cat").put("cat", 1L);
expectedConfusionMatrix.get("cat").put("dog", 5L);
expectedConfusionMatrix.get("cat").put("fox", 4L);
expectedConfusionMatrix.get("cat").put("mouse", 2L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("ant", 4L);
expectedConfusionMatrix.get("dog").put("cat", 2L);
expectedConfusionMatrix.get("dog").put("dog", 1L);
expectedConfusionMatrix.get("dog").put("fox", 5L);
expectedConfusionMatrix.get("dog").put("mouse", 3L);
expectedConfusionMatrix.put("fox", new HashMap<>());
expectedConfusionMatrix.get("fox").put("ant", 5L);
expectedConfusionMatrix.get("fox").put("cat", 3L);
expectedConfusionMatrix.get("fox").put("dog", 2L);
expectedConfusionMatrix.get("fox").put("fox", 1L);
expectedConfusionMatrix.get("fox").put("mouse", 4L);
expectedConfusionMatrix.put("mouse", new HashMap<>());
expectedConfusionMatrix.get("mouse").put("ant", 2L);
expectedConfusionMatrix.get("mouse").put("cat", 5L);
expectedConfusionMatrix.get("mouse").put("dog", 4L);
expectedConfusionMatrix.get("mouse").put("fox", 3L);
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
}
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
new EvaluateDataFrameAction.Request()
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
.setEvaluation(
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
MulticlassConfusionMatrix.Result confusionMatrixResult =
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
expectedConfusionMatrix.put("ant", new HashMap<>());
expectedConfusionMatrix.get("ant").put("ant", 1L);
expectedConfusionMatrix.get("ant").put("cat", 4L);
expectedConfusionMatrix.get("ant").put("dog", 3L);
expectedConfusionMatrix.get("ant").put("_other_", 7L);
expectedConfusionMatrix.put("cat", new HashMap<>());
expectedConfusionMatrix.get("cat").put("ant", 3L);
expectedConfusionMatrix.get("cat").put("cat", 1L);
expectedConfusionMatrix.get("cat").put("dog", 5L);
expectedConfusionMatrix.get("cat").put("_other_", 6L);
expectedConfusionMatrix.put("dog", new HashMap<>());
expectedConfusionMatrix.get("dog").put("ant", 4L);
expectedConfusionMatrix.get("dog").put("cat", 2L);
expectedConfusionMatrix.get("dog").put("dog", 1L);
expectedConfusionMatrix.get("dog").put("_other_", 8L);
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L));
}
private static void indexAnimalsData(String indexName) {
client().admin().indices().prepareCreate(indexName)
.addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
.get();
List<String> classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < classNames.size(); i++) {
for (int j = 0; j < classNames.size(); j++) {
for (int k = 0; k < j + 1; k++) {
bulkRequestBuilder.add(
new IndexRequest(indexName)
.source(
ACTUAL_CLASS_FIELD, classNames.get(i),
PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
}
}
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
}

View File

@ -11,6 +11,8 @@ setup:
"outlier_score": 0.0,
"regression_field_act": 10.9,
"regression_field_pred": 10.9,
"classification_field_act": "dog",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
}
@ -26,6 +28,8 @@ setup:
"outlier_score": 0.2,
"regression_field_act": 12.0,
"regression_field_pred": 9.9,
"classification_field_act": "cat",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
}
@ -41,6 +45,8 @@ setup:
"outlier_score": 0.3,
"regression_field_act": 20.9,
"regression_field_pred": 5.9,
"classification_field_act": "mouse",
"classification_field_pred": "mouse",
"all_true_field": true,
"all_false_field": false
}
@ -56,6 +62,8 @@ setup:
"outlier_score": 0.3,
"regression_field_act": 11.9,
"regression_field_pred": 11.9,
"classification_field_act": "dog",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
}
@ -71,6 +79,8 @@ setup:
"outlier_score": 0.4,
"regression_field_act": 42.9,
"regression_field_pred": 42.9,
"classification_field_act": "cat",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
}
@ -86,6 +96,8 @@ setup:
"outlier_score": 0.5,
"regression_field_act": 0.42,
"regression_field_pred": 0.42,
"classification_field_act": "dog",
"classification_field_pred": "dog",
"all_true_field": true,
"all_false_field": false
}
@ -101,6 +113,8 @@ setup:
"outlier_score": 0.9,
"regression_field_act": 1.1235813,
"regression_field_pred": 1.12358,
"classification_field_act": "cat",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
}
@ -116,6 +130,8 @@ setup:
"outlier_score": 0.95,
"regression_field_act": -5.20,
"regression_field_pred": -5.1,
"classification_field_act": "mouse",
"classification_field_pred": "cat",
"all_true_field": true,
"all_false_field": false
}
@ -569,6 +585,108 @@ setup:
}
}
}
---
"Test classification given evaluation with empty metrics":
- do:
catch: /\[classification\] must have one or more metrics/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { }
}
}
}
---
"Test classification multiclass_confusion_matrix":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { "multiclass_confusion_matrix": {} }
}
}
}
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
- match: { classification.multiclass_confusion_matrix._other_: 0 }
---
"Test classification multiclass_confusion_matrix with explicit size":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { "multiclass_confusion_matrix": { "size": 2 } }
}
}
}
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } }
- match: { classification.multiclass_confusion_matrix._other_: 1 }
---
"Test classification with null metrics":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword"
}
}
}
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
- match: { classification.multiclass_confusion_matrix._other_: 0 }
---
"Test classification given missing actual_field":
- do:
catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "missing",
"predicted_field": "classification_field_pred.keyword"
}
}
}
---
"Test classification given missing predicted_field":
- do:
catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "missing"
}
}
}
---
"Test regression given evaluation with empty metrics":
- do: