mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-25 17:38:44 +00:00
This commit is contained in:
parent
e3aab1295e
commit
ee952da2e2
@ -18,7 +18,9 @@
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
@ -41,6 +43,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
// Evaluations
|
||||
new NamedXContentRegistry.Entry(
|
||||
Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
|
||||
// Evaluation metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
|
||||
@ -48,6 +51,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(MulticlassConfusionMatrixMetric.NAME),
|
||||
MulticlassConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
@ -60,10 +67,14 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent),
|
||||
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(MulticlassConfusionMatrixMetric.NAME),
|
||||
MulticlassConfusionMatrixMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
|
||||
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,132 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Evaluation of classification results.
|
||||
*/
|
||||
public class Classification implements Evaluation {
|
||||
|
||||
public static final String NAME = "classification";
|
||||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
|
||||
}
|
||||
|
||||
public static Classification fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String predictedField;
|
||||
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public Classification(String actualField, String predictedField) {
|
||||
this(actualField, predictedField, (List<EvaluationMetric>)null);
|
||||
}
|
||||
|
||||
public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
|
||||
this(actualField, predictedField, Arrays.asList(metrics));
|
||||
}
|
||||
|
||||
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = Objects.requireNonNull(actualField);
|
||||
this.predictedField = Objects.requireNonNull(predictedField);
|
||||
if (metrics != null) {
|
||||
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
|
||||
}
|
||||
this.metrics = metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
if (metrics != null) {
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Classification that = (Classification) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
&& Objects.equals(that.metrics, this.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedField, metrics);
|
||||
}
|
||||
}
|
@ -0,0 +1,164 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.TreeMap;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
* Calculates the multiclass confusion matrix.
|
||||
*/
|
||||
public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "multiclass_confusion_matrix";
|
||||
|
||||
public static final ParseField SIZE = new ParseField("size");
|
||||
|
||||
private static final ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> PARSER = createParser();
|
||||
|
||||
private static ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> createParser() {
|
||||
ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> parser =
|
||||
new ConstructingObjectParser<>(NAME, true, args -> new MulticlassConfusionMatrixMetric((Integer) args[0]));
|
||||
parser.declareInt(optionalConstructorArg(), SIZE);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static MulticlassConfusionMatrixMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final Integer size;
|
||||
|
||||
public MulticlassConfusionMatrixMetric() {
|
||||
this(null);
|
||||
}
|
||||
|
||||
public MulticlassConfusionMatrixMetric(@Nullable Integer size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (size != null) {
|
||||
builder.field(SIZE.getPreferredName(), size);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
MulticlassConfusionMatrixMetric that = (MulticlassConfusionMatrixMetric) o;
|
||||
return Objects.equals(this.size, that.size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(size);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
|
||||
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObject(
|
||||
constructorArg(),
|
||||
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
|
||||
CONFUSION_MATRIX);
|
||||
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
// Immutable
|
||||
private final Map<String, Map<String, Long>> confusionMatrix;
|
||||
private final long otherClassesCount;
|
||||
|
||||
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
|
||||
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
|
||||
this.otherClassesCount = otherClassesCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public Map<String, Map<String, Long>> getConfusionMatrix() {
|
||||
return confusionMatrix;
|
||||
}
|
||||
|
||||
public long getOtherClassesCount() {
|
||||
return otherClassesCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
|
||||
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
|
||||
&& this.otherClassesCount == that.otherClassesCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(confusionMatrix, otherClassesCount);
|
||||
}
|
||||
}
|
||||
}
|
@ -125,7 +125,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
@ -1638,19 +1640,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
|
||||
public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
|
||||
String indexName = "evaluate-test-index";
|
||||
createIndex(indexName, mappingForClassification());
|
||||
createIndex(indexName, mappingForSoftClassification());
|
||||
BulkRequest bulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForClassification(indexName, "blue", false, 0.1)) // #0
|
||||
.add(docForClassification(indexName, "blue", false, 0.2)) // #1
|
||||
.add(docForClassification(indexName, "blue", false, 0.3)) // #2
|
||||
.add(docForClassification(indexName, "blue", false, 0.4)) // #3
|
||||
.add(docForClassification(indexName, "blue", false, 0.7)) // #4
|
||||
.add(docForClassification(indexName, "blue", true, 0.2)) // #5
|
||||
.add(docForClassification(indexName, "green", true, 0.3)) // #6
|
||||
.add(docForClassification(indexName, "green", true, 0.4)) // #7
|
||||
.add(docForClassification(indexName, "green", true, 0.8)) // #8
|
||||
.add(docForClassification(indexName, "green", true, 0.9)); // #9
|
||||
.add(docForSoftClassification(indexName, "blue", false, 0.1)) // #0
|
||||
.add(docForSoftClassification(indexName, "blue", false, 0.2)) // #1
|
||||
.add(docForSoftClassification(indexName, "blue", false, 0.3)) // #2
|
||||
.add(docForSoftClassification(indexName, "blue", false, 0.4)) // #3
|
||||
.add(docForSoftClassification(indexName, "blue", false, 0.7)) // #4
|
||||
.add(docForSoftClassification(indexName, "blue", true, 0.2)) // #5
|
||||
.add(docForSoftClassification(indexName, "green", true, 0.3)) // #6
|
||||
.add(docForSoftClassification(indexName, "green", true, 0.4)) // #7
|
||||
.add(docForSoftClassification(indexName, "green", true, 0.8)) // #8
|
||||
.add(docForSoftClassification(indexName, "green", true, 0.9)); // #9
|
||||
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
@ -1712,19 +1714,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
|
||||
public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
|
||||
String indexName = "evaluate-with-query-test-index";
|
||||
createIndex(indexName, mappingForClassification());
|
||||
createIndex(indexName, mappingForSoftClassification());
|
||||
BulkRequest bulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #0
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #1
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #2
|
||||
.add(docForClassification(indexName, "blue", true, 1.0)) // #3
|
||||
.add(docForClassification(indexName, "blue", true, 0.0)) // #4
|
||||
.add(docForClassification(indexName, "blue", true, 0.0)) // #5
|
||||
.add(docForClassification(indexName, "green", true, 0.0)) // #6
|
||||
.add(docForClassification(indexName, "green", true, 0.0)) // #7
|
||||
.add(docForClassification(indexName, "green", true, 0.0)) // #8
|
||||
.add(docForClassification(indexName, "green", true, 1.0)); // #9
|
||||
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #0
|
||||
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #1
|
||||
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #2
|
||||
.add(docForSoftClassification(indexName, "blue", true, 1.0)) // #3
|
||||
.add(docForSoftClassification(indexName, "blue", true, 0.0)) // #4
|
||||
.add(docForSoftClassification(indexName, "blue", true, 0.0)) // #5
|
||||
.add(docForSoftClassification(indexName, "green", true, 0.0)) // #6
|
||||
.add(docForSoftClassification(indexName, "green", true, 0.0)) // #7
|
||||
.add(docForSoftClassification(indexName, "green", true, 0.0)) // #8
|
||||
.add(docForSoftClassification(indexName, "green", true, 1.0)); // #9
|
||||
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
@ -1787,6 +1789,85 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
|
||||
}
|
||||
|
||||
public void testEvaluateDataFrame_Classification() throws IOException {
|
||||
String indexName = "evaluate-classification-test-index";
|
||||
createIndex(indexName, mappingForClassification());
|
||||
BulkRequest regressionBulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForClassification(indexName, "cat", "cat"))
|
||||
.add(docForClassification(indexName, "cat", "cat"))
|
||||
.add(docForClassification(indexName, "cat", "cat"))
|
||||
.add(docForClassification(indexName, "cat", "dog"))
|
||||
.add(docForClassification(indexName, "cat", "fish"))
|
||||
.add(docForClassification(indexName, "dog", "cat"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "horse", "cat"));
|
||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
|
||||
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
MulticlassConfusionMatrixMetric.Result mcmResult =
|
||||
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
|
||||
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("horse", 0L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 1L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("dog").put("horse", 0L);
|
||||
expectedConfusionMatrix.put("horse", new HashMap<>());
|
||||
expectedConfusionMatrix.get("horse").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("horse").put("dog", 0L);
|
||||
expectedConfusionMatrix.get("horse").put("horse", 0L);
|
||||
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
|
||||
}
|
||||
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
MulticlassConfusionMatrixMetric.Result mcmResult =
|
||||
evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
|
||||
assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 1L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 3L);
|
||||
assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
|
||||
}
|
||||
}
|
||||
|
||||
private static XContentBuilder defaultMappingForTest() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
@ -1804,7 +1885,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
private static final String actualField = "label";
|
||||
private static final String probabilityField = "p";
|
||||
|
||||
private static XContentBuilder mappingForClassification() throws IOException {
|
||||
private static XContentBuilder mappingForSoftClassification() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
.startObject(datasetField)
|
||||
@ -1820,26 +1901,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
|
||||
private static IndexRequest docForSoftClassification(String indexName, String dataset, boolean isTrue, double p) {
|
||||
return new IndexRequest()
|
||||
.index(indexName)
|
||||
.source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
|
||||
}
|
||||
|
||||
private static final String actualClassField = "actual_class";
|
||||
private static final String predictedClassField = "predicted_class";
|
||||
|
||||
private static XContentBuilder mappingForClassification() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
.startObject(actualClassField)
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.startObject(predictedClassField)
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
|
||||
return new IndexRequest()
|
||||
.index(indexName)
|
||||
.source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
|
||||
}
|
||||
|
||||
private static final String actualRegression = "regression_actual";
|
||||
private static final String probabilityRegression = "regression_prob";
|
||||
|
||||
private static XContentBuilder mappingForRegression() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
.startObject(actualRegression)
|
||||
.field("type", "double")
|
||||
.startObject(actualRegression)
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.startObject(probabilityRegression)
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.startObject(probabilityRegression)
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject();
|
||||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForRegression(String indexName, double act, double p) {
|
||||
@ -1854,11 +1957,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
|
||||
public void testEstimateMemoryUsage() throws IOException {
|
||||
String indexName = "estimate-test-index";
|
||||
createIndex(indexName, mappingForClassification());
|
||||
createIndex(indexName, mappingForSoftClassification());
|
||||
BulkRequest bulk1 = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
bulk1.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
}
|
||||
highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
|
||||
|
||||
@ -1884,7 +1987,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
BulkRequest bulk2 = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 10; i < 100; ++i) {
|
||||
bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
bulk2.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
||||
}
|
||||
highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
|
||||
|
||||
|
@ -57,7 +57,9 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction;
|
||||
import org.elasticsearch.client.indexlifecycle.UnfollowAction;
|
||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
@ -684,7 +686,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(45, namedXContents.size());
|
||||
assertEquals(48, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
@ -724,22 +726,24 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName()));
|
||||
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
|
||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
MulticlassConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
MulticlassConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
||||
|
@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class ClassificationTests extends AbstractXContentTestCase<Classification> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static Classification createRandom() {
|
||||
return new Classification(
|
||||
randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : Arrays.asList(new MulticlassConfusionMatrixMetric()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
||||
return Classification.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
// allow unknown fields in the root of the object only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
}
|
@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric.Result> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric.Result createTestInstance() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
Map<String, Long> row = new TreeMap<>();
|
||||
confusionMatrix.put(classNames.get(i), row);
|
||||
for (int j = 0; j < numClasses; j++) {
|
||||
if (randomBoolean()) {
|
||||
row.put(classNames.get(i), randomNonNegativeLong());
|
||||
}
|
||||
}
|
||||
}
|
||||
long otherClassesCount = randomNonNegativeLong();
|
||||
return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrixMetric.Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
// allow unknown fields in the root of the object only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric createTestInstance() {
|
||||
Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null;
|
||||
return new MulticlassConfusionMatrixMetric(size);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrixMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
@ -139,6 +139,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
|
||||
@ -469,6 +471,14 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
|
||||
// ML - Data frame evaluation
|
||||
new NamedWriteableRegistry.Entry(
|
||||
Evaluation.class,
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification.NAME.getPreferredName(),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification::new),
|
||||
new NamedWriteableRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix.Result::new),
|
||||
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new),
|
||||
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new),
|
||||
|
@ -8,7 +8,10 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
|
||||
@ -32,6 +35,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
// Evaluations
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
|
||||
BinarySoftClassification::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
|
||||
|
||||
// Soft classification metrics
|
||||
@ -41,6 +45,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
|
||||
ConfusionMatrix::fromXContent));
|
||||
|
||||
// Classification metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
|
||||
MulticlassConfusionMatrix::fromXContent));
|
||||
|
||||
// Regression metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
|
||||
@ -54,6 +62,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
// Evaluations
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(),
|
||||
Classification::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
|
||||
|
||||
// Evaluation Metrics
|
||||
@ -65,6 +75,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
Recall::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
|
||||
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError::new));
|
||||
@ -79,6 +92,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
ScoreByThresholdResult::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError.Result::new));
|
||||
|
@ -0,0 +1,173 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Evaluation of classification results.
|
||||
*/
|
||||
public class Classification implements Evaluation {
|
||||
|
||||
public static final ParseField NAME = new ParseField("classification");
|
||||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<ClassificationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS);
|
||||
}
|
||||
|
||||
public static Classification fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String predictedField;
|
||||
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<ClassificationMetric> metrics;
|
||||
|
||||
public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
|
||||
this.metrics = initMetrics(metrics);
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedField = in.readString();
|
||||
this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
|
||||
}
|
||||
|
||||
private static List<ClassificationMetric> initMetrics(@Nullable List<ClassificationMetric> parsedMetrics) {
|
||||
List<ClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
|
||||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private static List<ClassificationMetric> defaultMetrics() {
|
||||
return Arrays.asList(new MulticlassConfusionMatrix());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ClassificationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
||||
ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
|
||||
SearchSourceBuilder searchSourceBuilder =
|
||||
newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder);
|
||||
for (ClassificationMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(SearchResponse searchResponse) {
|
||||
ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
|
||||
}
|
||||
for (ClassificationMetric metric : metrics) {
|
||||
metric.process(searchResponse.getAggregations());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualField);
|
||||
out.writeString(predictedField);
|
||||
out.writeNamedWriteableList(metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (ClassificationMetric metric : metrics) {
|
||||
builder.field(metric.getWriteableName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Classification that = (Classification) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
&& Objects.equals(that.metrics, this.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedField, metrics);
|
||||
}
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ClassificationMetric extends EvaluationMetric {
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param actualField the field that stores the actual value
|
||||
* @param predictedField the field that stores the predicted value
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
List<AggregationBuilder> aggs(String actualField, String predictedField);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
* @param aggs aggregations from {@link SearchResponse}
|
||||
*/
|
||||
void process(Aggregations aggs);
|
||||
}
|
@ -0,0 +1,277 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.BucketOrder;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.TreeMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
* {@link MulticlassConfusionMatrix} is a metric that answers the question:
|
||||
* "How many examples belonging to class X were classified as Y by the classifier?"
|
||||
* for all the possible class pairs {X, Y}.
|
||||
*/
|
||||
public class MulticlassConfusionMatrix implements ClassificationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
|
||||
|
||||
public static final ParseField SIZE = new ParseField("size");
|
||||
|
||||
private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
|
||||
|
||||
private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
|
||||
ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
|
||||
parser.declareInt(optionalConstructorArg(), SIZE);
|
||||
return parser;
|
||||
}
|
||||
|
||||
public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
|
||||
private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
|
||||
private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
|
||||
private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
|
||||
private static final String OTHER_BUCKET_KEY = "_other_";
|
||||
private static final int DEFAULT_SIZE = 10;
|
||||
private static final int MAX_SIZE = 1000;
|
||||
|
||||
private final int size;
|
||||
private List<String> topActualClassNames;
|
||||
private Result result;
|
||||
|
||||
public MulticlassConfusionMatrix() {
|
||||
this((Integer) null);
|
||||
}
|
||||
|
||||
public MulticlassConfusionMatrix(@Nullable Integer size) {
|
||||
if (size != null && (size <= 0 || size > MAX_SIZE)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
|
||||
}
|
||||
this.size = size != null ? size : DEFAULT_SIZE;
|
||||
}
|
||||
|
||||
public MulticlassConfusionMatrix(StreamInput in) throws IOException {
|
||||
this.size = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public int getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
if (topActualClassNames == null) { // This is step 1
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(size));
|
||||
}
|
||||
if (result == null) { // This is step 2
|
||||
KeyedFilter[] keyedFilters =
|
||||
topActualClassNames.stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
|
||||
.field(actualField),
|
||||
AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(size)
|
||||
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters)
|
||||
.otherBucket(true)
|
||||
.otherBucketKey(OTHER_BUCKET_KEY)));
|
||||
}
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
||||
Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
|
||||
topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList());
|
||||
}
|
||||
if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
|
||||
Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
|
||||
Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
|
||||
Map<String, Map<String, Long>> counts = new TreeMap<>();
|
||||
for (Terms.Bucket bucket : termsAgg.getBuckets()) {
|
||||
String actualClass = bucket.getKeyAsString();
|
||||
Map<String, Long> subCounts = new TreeMap<>();
|
||||
counts.put(actualClass, subCounts);
|
||||
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
|
||||
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
|
||||
String predictedClass = subBucket.getKeyAsString();
|
||||
Long docCount = subBucket.getDocCount();
|
||||
if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) {
|
||||
subCounts.put(predictedClass, docCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(SIZE.getPreferredName(), size);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
|
||||
return Objects.equals(this.size, that.size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(size);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
|
||||
private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObject(
|
||||
constructorArg(),
|
||||
(p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
|
||||
CONFUSION_MATRIX);
|
||||
PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
// Immutable
|
||||
private final Map<String, Map<String, Long>> confusionMatrix;
|
||||
private final long otherClassesCount;
|
||||
|
||||
public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
|
||||
this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
|
||||
this.otherClassesCount = otherClassesCount;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.confusionMatrix = Collections.unmodifiableMap(
|
||||
in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong)));
|
||||
this.otherClassesCount = in.readLong();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public Map<String, Map<String, Long>> getConfusionMatrix() {
|
||||
return confusionMatrix;
|
||||
}
|
||||
|
||||
public long getOtherClassesCount() {
|
||||
return otherClassesCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeMap(
|
||||
confusionMatrix,
|
||||
StreamOutput::writeString,
|
||||
(out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong));
|
||||
out.writeLong(otherClassesCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
|
||||
builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
|
||||
&& this.otherClassesCount == that.otherClassesCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(confusionMatrix, otherClassesCount);
|
||||
}
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
|
||||
@ -29,6 +30,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
|
||||
String evaluationName = randomAlphaOfLength(10);
|
||||
List<EvaluationMetricResult> metrics =
|
||||
Arrays.asList(
|
||||
MulticlassConfusionMatrixResultTests.createRandom(),
|
||||
new MeanSquaredError.Result(randomDouble()),
|
||||
new RSquared.Result(randomDouble()));
|
||||
int numMetrics = randomIntBetween(0, metrics.size());
|
||||
|
@ -0,0 +1,222 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static Classification createRandom() {
|
||||
return new Classification(
|
||||
randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification doParseInstance(XContentParser parser) throws IOException {
|
||||
return Classification.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Classification createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Classification> instanceReader() {
|
||||
return Classification::new;
|
||||
}
|
||||
|
||||
public void testConstructor_GivenEmptyMetrics() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", "bar", Collections.emptyList()));
|
||||
assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics"));
|
||||
}
|
||||
|
||||
public void testBuildSearch() {
|
||||
QueryBuilder userProvidedQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value"));
|
||||
QueryBuilder expectedSearchQuery =
|
||||
QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery("act"))
|
||||
.filter(QueryBuilders.existsQuery("pred"))
|
||||
.filter(QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.termQuery("field_A", "some-value"))
|
||||
.filter(QueryBuilders.termQuery("field_B", "some-other-value")));
|
||||
|
||||
Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
|
||||
|
||||
SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
|
||||
assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
|
||||
assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
|
||||
}
|
||||
|
||||
public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() {
|
||||
ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
|
||||
ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
|
||||
ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
|
||||
ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
|
||||
|
||||
Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
|
||||
assertThat(metric1.getResult(), isEmpty());
|
||||
assertThat(metric2.getResult(), isEmpty());
|
||||
assertThat(metric3.getResult(), isEmpty());
|
||||
assertThat(metric4.getResult(), isEmpty());
|
||||
assertThat(evaluation.hasAllResults(), is(false));
|
||||
|
||||
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
|
||||
assertThat(metric1.getResult(), isEmpty());
|
||||
assertThat(metric2.getResult(), isEmpty());
|
||||
assertThat(metric3.getResult(), isEmpty());
|
||||
assertThat(metric4.getResult(), isEmpty());
|
||||
assertThat(evaluation.hasAllResults(), is(false));
|
||||
|
||||
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
|
||||
assertThat(metric1.getResult(), isPresent());
|
||||
assertThat(metric2.getResult(), isEmpty());
|
||||
assertThat(metric3.getResult(), isEmpty());
|
||||
assertThat(metric4.getResult(), isEmpty());
|
||||
assertThat(evaluation.hasAllResults(), is(false));
|
||||
|
||||
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
|
||||
assertThat(metric1.getResult(), isPresent());
|
||||
assertThat(metric2.getResult(), isPresent());
|
||||
assertThat(metric3.getResult(), isEmpty());
|
||||
assertThat(metric4.getResult(), isEmpty());
|
||||
assertThat(evaluation.hasAllResults(), is(false));
|
||||
|
||||
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
|
||||
assertThat(metric1.getResult(), isPresent());
|
||||
assertThat(metric2.getResult(), isPresent());
|
||||
assertThat(metric3.getResult(), isPresent());
|
||||
assertThat(metric4.getResult(), isEmpty());
|
||||
assertThat(evaluation.hasAllResults(), is(false));
|
||||
|
||||
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
|
||||
assertThat(metric1.getResult(), isPresent());
|
||||
assertThat(metric2.getResult(), isPresent());
|
||||
assertThat(metric3.getResult(), isPresent());
|
||||
assertThat(metric4.getResult(), isPresent());
|
||||
assertThat(evaluation.hasAllResults(), is(true));
|
||||
|
||||
evaluation.process(mockSearchResponseWithNonZeroTotalHits());
|
||||
assertThat(metric1.getResult(), isPresent());
|
||||
assertThat(metric2.getResult(), isPresent());
|
||||
assertThat(metric3.getResult(), isPresent());
|
||||
assertThat(metric4.getResult(), isPresent());
|
||||
assertThat(evaluation.hasAllResults(), is(true));
|
||||
}
|
||||
|
||||
private static SearchResponse mockSearchResponseWithNonZeroTotalHits() {
|
||||
SearchResponse searchResponse = mock(SearchResponse.class);
|
||||
SearchHits hits = new SearchHits(SearchHits.EMPTY, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0);
|
||||
when(searchResponse.getHits()).thenReturn(hits);
|
||||
return searchResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Metric which iterates through its steps in {@link #process} method.
|
||||
* Number of steps is configurable.
|
||||
* Upon reaching the last step, the result is produced.
|
||||
*/
|
||||
private static class FakeClassificationMetric implements ClassificationMetric {
|
||||
|
||||
private final String name;
|
||||
private final int numSteps;
|
||||
private int currentStepIndex;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
FakeClassificationMetric(String name, int numSteps) {
|
||||
this.name = name;
|
||||
this.numSteps = numSteps;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (result != null) {
|
||||
return;
|
||||
}
|
||||
currentStepIndex++;
|
||||
if (currentStepIndex == numSteps) {
|
||||
// This is the last step, time to write evaluation result
|
||||
result = mock(EvaluationMetricResult.class);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) {
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) {
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix.Result> {
|
||||
|
||||
public static MulticlassConfusionMatrix.Result createRandom() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
Map<String, Long> row = new TreeMap<>();
|
||||
confusionMatrix.put(classNames.get(i), row);
|
||||
for (int j = 0; j < numClasses; j++) {
|
||||
if (randomBoolean()) {
|
||||
row.put(classNames.get(i), randomNonNegativeLong());
|
||||
}
|
||||
}
|
||||
}
|
||||
long otherClassesCount = randomNonNegativeLong();
|
||||
return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrix.Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrix.Result createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<MulticlassConfusionMatrix.Result> instanceReader() {
|
||||
return MulticlassConfusionMatrix.Result::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
// allow unknown fields in the root of the object only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
}
|
@ -0,0 +1,205 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.mockito.Mockito.doReturn;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
|
||||
return MulticlassConfusionMatrix.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MulticlassConfusionMatrix createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<MulticlassConfusionMatrix> instanceReader() {
|
||||
return MulticlassConfusionMatrix::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public static MulticlassConfusionMatrix createRandom() {
|
||||
Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
|
||||
return new MulticlassConfusionMatrix(size);
|
||||
}
|
||||
|
||||
public void testConstructor_SizeValidationFailures() {
|
||||
{
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1));
|
||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||
}
|
||||
{
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0));
|
||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||
}
|
||||
{
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001));
|
||||
assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testAggs() {
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
|
||||
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
|
||||
assertThat(aggs, is(not(empty())));
|
||||
assertThat(confusionMatrix.getResult(), equalTo(Optional.empty()));
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
"multiclass_confusion_matrix_step_1_by_actual_class",
|
||||
Arrays.asList(
|
||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
0L),
|
||||
mockTerms(
|
||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
||||
Arrays.asList(
|
||||
mockTermsBucket(
|
||||
"dog",
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockTermsBucket(
|
||||
"cat",
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))),
|
||||
0L),
|
||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 10L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 20L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 30L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 40L);
|
||||
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(result.getOtherClassesCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_OtherClassesCountGreaterThanZero() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(
|
||||
"multiclass_confusion_matrix_step_1_by_actual_class",
|
||||
Arrays.asList(
|
||||
mockTermsBucket("dog", new Aggregations(Collections.emptyList())),
|
||||
mockTermsBucket("cat", new Aggregations(Collections.emptyList()))),
|
||||
100L),
|
||||
mockTerms(
|
||||
"multiclass_confusion_matrix_step_2_by_actual_class",
|
||||
Arrays.asList(
|
||||
mockTermsBucket(
|
||||
"dog",
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
|
||||
mockTermsBucket(
|
||||
"cat",
|
||||
new Aggregations(Arrays.asList(mockFilters(
|
||||
"multiclass_confusion_matrix_step_2_by_predicted_class",
|
||||
Arrays.asList(
|
||||
mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))),
|
||||
100L),
|
||||
mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
|
||||
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("cat", 10L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 20L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("cat", 30L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 40L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 15L);
|
||||
assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(result.getOtherClassesCount(), equalTo(3L));
|
||||
}
|
||||
|
||||
private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
|
||||
Terms aggregation = mock(Terms.class);
|
||||
when(aggregation.getName()).thenReturn(name);
|
||||
doReturn(buckets).when(aggregation).getBuckets();
|
||||
when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
|
||||
return aggregation;
|
||||
}
|
||||
|
||||
private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) {
|
||||
Terms.Bucket bucket = mock(Terms.Bucket.class);
|
||||
when(bucket.getKeyAsString()).thenReturn(actualClass);
|
||||
when(bucket.getAggregations()).thenReturn(subAggs);
|
||||
return bucket;
|
||||
}
|
||||
|
||||
private static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
|
||||
Filters aggregation = mock(Filters.class);
|
||||
when(aggregation.getName()).thenReturn(name);
|
||||
doReturn(buckets).when(aggregation).getBuckets();
|
||||
return aggregation;
|
||||
}
|
||||
|
||||
private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) {
|
||||
Filters.Bucket bucket = mock(Filters.Bucket.class);
|
||||
when(bucket.getKeyAsString()).thenReturn(predictedClass);
|
||||
when(bucket.getDocCount()).thenReturn(docCount);
|
||||
return bucket;
|
||||
}
|
||||
|
||||
private static Cardinality mockCardinality(String name, long value) {
|
||||
Cardinality aggregation = mock(Cardinality.class);
|
||||
when(aggregation.getName()).thenReturn(name);
|
||||
when(aggregation.getValue()).thenReturn(value);
|
||||
return aggregation;
|
||||
}
|
||||
}
|
@ -100,6 +100,9 @@ integTest.runner {
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test classification given evaluation with empty metrics',
|
||||
'ml/evaluate_data_frame/Test classification given missing actual_field',
|
||||
'ml/evaluate_data_frame/Test classification given missing predicted_field',
|
||||
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
|
||||
'ml/evaluate_data_frame/Test regression given missing actual_field',
|
||||
'ml/evaluate_data_frame/Test regression given missing predicted_field',
|
||||
|
@ -0,0 +1,200 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
|
||||
|
||||
private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
|
||||
private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
indexAnimalsData(ANIMALS_DATA_INDEX);
|
||||
}
|
||||
|
||||
@After
|
||||
public void cleanup() {
|
||||
cleanUp();
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_DefaultMetrics() {
|
||||
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameAction.Request()
|
||||
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
|
||||
.setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
|
||||
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("ant", new HashMap<>());
|
||||
expectedConfusionMatrix.get("ant").put("ant", 1L);
|
||||
expectedConfusionMatrix.get("ant").put("cat", 4L);
|
||||
expectedConfusionMatrix.get("ant").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("ant").put("fox", 2L);
|
||||
expectedConfusionMatrix.get("ant").put("mouse", 5L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("ant", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 5L);
|
||||
expectedConfusionMatrix.get("cat").put("fox", 4L);
|
||||
expectedConfusionMatrix.get("cat").put("mouse", 2L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("ant", 4L);
|
||||
expectedConfusionMatrix.get("dog").put("cat", 2L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("fox", 5L);
|
||||
expectedConfusionMatrix.get("dog").put("mouse", 3L);
|
||||
expectedConfusionMatrix.put("fox", new HashMap<>());
|
||||
expectedConfusionMatrix.get("fox").put("ant", 5L);
|
||||
expectedConfusionMatrix.get("fox").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("fox").put("dog", 2L);
|
||||
expectedConfusionMatrix.get("fox").put("fox", 1L);
|
||||
expectedConfusionMatrix.get("fox").put("mouse", 4L);
|
||||
expectedConfusionMatrix.put("mouse", new HashMap<>());
|
||||
expectedConfusionMatrix.get("mouse").put("ant", 2L);
|
||||
expectedConfusionMatrix.get("mouse").put("cat", 5L);
|
||||
expectedConfusionMatrix.get("mouse").put("dog", 4L);
|
||||
expectedConfusionMatrix.get("mouse").put("fox", 3L);
|
||||
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
|
||||
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
|
||||
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameAction.Request()
|
||||
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
|
||||
.setEvaluation(
|
||||
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix())));
|
||||
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("ant", new HashMap<>());
|
||||
expectedConfusionMatrix.get("ant").put("ant", 1L);
|
||||
expectedConfusionMatrix.get("ant").put("cat", 4L);
|
||||
expectedConfusionMatrix.get("ant").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("ant").put("fox", 2L);
|
||||
expectedConfusionMatrix.get("ant").put("mouse", 5L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("ant", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 5L);
|
||||
expectedConfusionMatrix.get("cat").put("fox", 4L);
|
||||
expectedConfusionMatrix.get("cat").put("mouse", 2L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("ant", 4L);
|
||||
expectedConfusionMatrix.get("dog").put("cat", 2L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("fox", 5L);
|
||||
expectedConfusionMatrix.get("dog").put("mouse", 3L);
|
||||
expectedConfusionMatrix.put("fox", new HashMap<>());
|
||||
expectedConfusionMatrix.get("fox").put("ant", 5L);
|
||||
expectedConfusionMatrix.get("fox").put("cat", 3L);
|
||||
expectedConfusionMatrix.get("fox").put("dog", 2L);
|
||||
expectedConfusionMatrix.get("fox").put("fox", 1L);
|
||||
expectedConfusionMatrix.get("fox").put("mouse", 4L);
|
||||
expectedConfusionMatrix.put("mouse", new HashMap<>());
|
||||
expectedConfusionMatrix.get("mouse").put("ant", 2L);
|
||||
expectedConfusionMatrix.get("mouse").put("cat", 5L);
|
||||
expectedConfusionMatrix.get("mouse").put("dog", 4L);
|
||||
expectedConfusionMatrix.get("mouse").put("fox", 3L);
|
||||
expectedConfusionMatrix.get("mouse").put("mouse", 1L);
|
||||
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
|
||||
EvaluateDataFrameAction.Request evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameAction.Request()
|
||||
.setIndices(Arrays.asList(ANIMALS_DATA_INDEX))
|
||||
.setEvaluation(
|
||||
new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3))));
|
||||
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
MulticlassConfusionMatrix.Result confusionMatrixResult =
|
||||
(MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
|
||||
Map<String, Map<String, Long>> expectedConfusionMatrix = new HashMap<>();
|
||||
expectedConfusionMatrix.put("ant", new HashMap<>());
|
||||
expectedConfusionMatrix.get("ant").put("ant", 1L);
|
||||
expectedConfusionMatrix.get("ant").put("cat", 4L);
|
||||
expectedConfusionMatrix.get("ant").put("dog", 3L);
|
||||
expectedConfusionMatrix.get("ant").put("_other_", 7L);
|
||||
expectedConfusionMatrix.put("cat", new HashMap<>());
|
||||
expectedConfusionMatrix.get("cat").put("ant", 3L);
|
||||
expectedConfusionMatrix.get("cat").put("cat", 1L);
|
||||
expectedConfusionMatrix.get("cat").put("dog", 5L);
|
||||
expectedConfusionMatrix.get("cat").put("_other_", 6L);
|
||||
expectedConfusionMatrix.put("dog", new HashMap<>());
|
||||
expectedConfusionMatrix.get("dog").put("ant", 4L);
|
||||
expectedConfusionMatrix.get("dog").put("cat", 2L);
|
||||
expectedConfusionMatrix.get("dog").put("dog", 1L);
|
||||
expectedConfusionMatrix.get("dog").put("_other_", 8L);
|
||||
assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix));
|
||||
assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L));
|
||||
}
|
||||
|
||||
private static void indexAnimalsData(String indexName) {
|
||||
client().admin().indices().prepareCreate(indexName)
|
||||
.addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
|
||||
.get();
|
||||
|
||||
List<String> classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox");
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < classNames.size(); i++) {
|
||||
for (int j = 0; j < classNames.size(); j++) {
|
||||
for (int k = 0; k < j + 1; k++) {
|
||||
bulkRequestBuilder.add(
|
||||
new IndexRequest(indexName)
|
||||
.source(
|
||||
ACTUAL_CLASS_FIELD, classNames.get(i),
|
||||
PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
|
||||
}
|
||||
}
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
}
|
@ -11,6 +11,8 @@ setup:
|
||||
"outlier_score": 0.0,
|
||||
"regression_field_act": 10.9,
|
||||
"regression_field_pred": 10.9,
|
||||
"classification_field_act": "dog",
|
||||
"classification_field_pred": "dog",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -26,6 +28,8 @@ setup:
|
||||
"outlier_score": 0.2,
|
||||
"regression_field_act": 12.0,
|
||||
"regression_field_pred": 9.9,
|
||||
"classification_field_act": "cat",
|
||||
"classification_field_pred": "cat",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -41,6 +45,8 @@ setup:
|
||||
"outlier_score": 0.3,
|
||||
"regression_field_act": 20.9,
|
||||
"regression_field_pred": 5.9,
|
||||
"classification_field_act": "mouse",
|
||||
"classification_field_pred": "mouse",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -56,6 +62,8 @@ setup:
|
||||
"outlier_score": 0.3,
|
||||
"regression_field_act": 11.9,
|
||||
"regression_field_pred": 11.9,
|
||||
"classification_field_act": "dog",
|
||||
"classification_field_pred": "cat",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -71,6 +79,8 @@ setup:
|
||||
"outlier_score": 0.4,
|
||||
"regression_field_act": 42.9,
|
||||
"regression_field_pred": 42.9,
|
||||
"classification_field_act": "cat",
|
||||
"classification_field_pred": "dog",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -86,6 +96,8 @@ setup:
|
||||
"outlier_score": 0.5,
|
||||
"regression_field_act": 0.42,
|
||||
"regression_field_pred": 0.42,
|
||||
"classification_field_act": "dog",
|
||||
"classification_field_pred": "dog",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -101,6 +113,8 @@ setup:
|
||||
"outlier_score": 0.9,
|
||||
"regression_field_act": 1.1235813,
|
||||
"regression_field_pred": 1.12358,
|
||||
"classification_field_act": "cat",
|
||||
"classification_field_pred": "cat",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -116,6 +130,8 @@ setup:
|
||||
"outlier_score": 0.95,
|
||||
"regression_field_act": -5.20,
|
||||
"regression_field_pred": -5.1,
|
||||
"classification_field_act": "mouse",
|
||||
"classification_field_pred": "cat",
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
@ -569,6 +585,108 @@ setup:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test classification given evaluation with empty metrics":
|
||||
- do:
|
||||
catch: /\[classification\] must have one or more metrics/
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "classification_field_pred.keyword",
|
||||
"metrics": { }
|
||||
}
|
||||
}
|
||||
}
|
||||
---
|
||||
"Test classification multiclass_confusion_matrix":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "classification_field_pred.keyword",
|
||||
"metrics": { "multiclass_confusion_matrix": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
|
||||
- match: { classification.multiclass_confusion_matrix._other_: 0 }
|
||||
---
|
||||
"Test classification multiclass_confusion_matrix with explicit size":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "classification_field_pred.keyword",
|
||||
"metrics": { "multiclass_confusion_matrix": { "size": 2 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } }
|
||||
- match: { classification.multiclass_confusion_matrix._other_: 1 }
|
||||
---
|
||||
"Test classification with null metrics":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "classification_field_pred.keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
|
||||
- match: { classification.multiclass_confusion_matrix._other_: 0 }
|
||||
---
|
||||
"Test classification given missing actual_field":
|
||||
- do:
|
||||
catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "missing",
|
||||
"predicted_field": "classification_field_pred.keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test classification given missing predicted_field":
|
||||
- do:
|
||||
catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "missing"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test regression given evaluation with empty metrics":
|
||||
- do:
|
||||
|
Loading…
x
Reference in New Issue
Block a user