diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index a28c498b1d5..dca644b663e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -18,7 +18,9 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -41,6 +43,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluations new NamedXContentRegistry.Entry( Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent), + new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent), new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent), // Evaluation metrics new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent), @@ -48,6 +51,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(MulticlassConfusionMatrixMetric.NAME), + MulticlassConfusionMatrixMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent), new NamedXContentRegistry.Entry( @@ -60,10 +67,14 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent), + EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, + new ParseField(MulticlassConfusionMatrixMetric.NAME), + MulticlassConfusionMatrixMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( - EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent)); + EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent)); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java new file mode 100644 index 00000000000..d7466fcc023 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java @@ -0,0 +1,132 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of classification results. + */ +public class Classification implements Evaluation { + + public static final String NAME = "classification"; + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, true, a -> new Classification((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS); + } + + public static Classification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field containing the actual value + * The value of this field is assumed to be numeric + */ + private final String actualField; + + /** + * The field containing the predicted value + * The value of this field is assumed to be numeric + */ + private final String predictedField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public Classification(String actualField, String predictedField) { + this(actualField, predictedField, (List)null); + } + + public Classification(String actualField, String predictedField, EvaluationMetric... metrics) { + this(actualField, predictedField, Arrays.asList(metrics)); + } + + public Classification(String actualField, String predictedField, @Nullable List metrics) { + this.actualField = Objects.requireNonNull(actualField); + this.predictedField = Objects.requireNonNull(predictedField); + if (metrics != null) { + metrics.sort(Comparator.comparing(EvaluationMetric::getName)); + } + this.metrics = metrics; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); + + if (metrics != null) { + builder.startObject(METRICS.getPreferredName()); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); + } + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Classification that = (Classification) o; + return Objects.equals(that.actualField, this.actualField) + && Objects.equals(that.predictedField, this.predictedField) + && Objects.equals(that.metrics, this.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedField, metrics); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java new file mode 100644 index 00000000000..a8e8545009b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java @@ -0,0 +1,164 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Calculates the multiclass confusion matrix. + */ +public class MulticlassConfusionMatrixMetric implements EvaluationMetric { + + public static final String NAME = "multiclass_confusion_matrix"; + + public static final ParseField SIZE = new ParseField("size"); + + private static final ConstructingObjectParser PARSER = createParser(); + + private static ConstructingObjectParser createParser() { + ConstructingObjectParser parser = + new ConstructingObjectParser<>(NAME, true, args -> new MulticlassConfusionMatrixMetric((Integer) args[0])); + parser.declareInt(optionalConstructorArg(), SIZE); + return parser; + } + + public static MulticlassConfusionMatrixMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final Integer size; + + public MulticlassConfusionMatrixMetric() { + this(null); + } + + public MulticlassConfusionMatrixMetric(@Nullable Integer size) { + this.size = size; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (size != null) { + builder.field(SIZE.getPreferredName(), size); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MulticlassConfusionMatrixMetric that = (MulticlassConfusionMatrixMetric) o; + return Objects.equals(this.size, that.size); + } + + @Override + public int hashCode() { + return Objects.hash(size); + } + + public static class Result implements EvaluationMetric.Result { + + private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); + private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_result", true, a -> new Result((Map>) a[0], (long) a[1])); + + static { + PARSER.declareObject( + constructorArg(), + (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)), + CONFUSION_MATRIX); + PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + // Immutable + private final Map> confusionMatrix; + private final long otherClassesCount; + + public Result(Map> confusionMatrix, long otherClassesCount) { + this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); + this.otherClassesCount = otherClassesCount; + } + + @Override + public String getMetricName() { + return NAME; + } + + public Map> getConfusionMatrix() { + return confusionMatrix; + } + + public long getOtherClassesCount() { + return otherClassesCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); + builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.confusionMatrix, that.confusionMatrix) + && this.otherClassesCount == that.otherClassesCount; + } + + @Override + public int hashCode() { + return Objects.hash(confusionMatrix, otherClassesCount); + } + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 9c8663d8eb3..a8eff6b9297 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -125,7 +125,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -1638,19 +1640,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { public void testEvaluateDataFrame_BinarySoftClassification() throws IOException { String indexName = "evaluate-test-index"; - createIndex(indexName, mappingForClassification()); + createIndex(indexName, mappingForSoftClassification()); BulkRequest bulk = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(docForClassification(indexName, "blue", false, 0.1)) // #0 - .add(docForClassification(indexName, "blue", false, 0.2)) // #1 - .add(docForClassification(indexName, "blue", false, 0.3)) // #2 - .add(docForClassification(indexName, "blue", false, 0.4)) // #3 - .add(docForClassification(indexName, "blue", false, 0.7)) // #4 - .add(docForClassification(indexName, "blue", true, 0.2)) // #5 - .add(docForClassification(indexName, "green", true, 0.3)) // #6 - .add(docForClassification(indexName, "green", true, 0.4)) // #7 - .add(docForClassification(indexName, "green", true, 0.8)) // #8 - .add(docForClassification(indexName, "green", true, 0.9)); // #9 + .add(docForSoftClassification(indexName, "blue", false, 0.1)) // #0 + .add(docForSoftClassification(indexName, "blue", false, 0.2)) // #1 + .add(docForSoftClassification(indexName, "blue", false, 0.3)) // #2 + .add(docForSoftClassification(indexName, "blue", false, 0.4)) // #3 + .add(docForSoftClassification(indexName, "blue", false, 0.7)) // #4 + .add(docForSoftClassification(indexName, "blue", true, 0.2)) // #5 + .add(docForSoftClassification(indexName, "green", true, 0.3)) // #6 + .add(docForSoftClassification(indexName, "green", true, 0.4)) // #7 + .add(docForSoftClassification(indexName, "green", true, 0.8)) // #8 + .add(docForSoftClassification(indexName, "green", true, 0.9)); // #9 highLevelClient().bulk(bulk, RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); @@ -1712,19 +1714,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException { String indexName = "evaluate-with-query-test-index"; - createIndex(indexName, mappingForClassification()); + createIndex(indexName, mappingForSoftClassification()); BulkRequest bulk = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(docForClassification(indexName, "blue", true, 1.0)) // #0 - .add(docForClassification(indexName, "blue", true, 1.0)) // #1 - .add(docForClassification(indexName, "blue", true, 1.0)) // #2 - .add(docForClassification(indexName, "blue", true, 1.0)) // #3 - .add(docForClassification(indexName, "blue", true, 0.0)) // #4 - .add(docForClassification(indexName, "blue", true, 0.0)) // #5 - .add(docForClassification(indexName, "green", true, 0.0)) // #6 - .add(docForClassification(indexName, "green", true, 0.0)) // #7 - .add(docForClassification(indexName, "green", true, 0.0)) // #8 - .add(docForClassification(indexName, "green", true, 1.0)); // #9 + .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #0 + .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #1 + .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #2 + .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #3 + .add(docForSoftClassification(indexName, "blue", true, 0.0)) // #4 + .add(docForSoftClassification(indexName, "blue", true, 0.0)) // #5 + .add(docForSoftClassification(indexName, "green", true, 0.0)) // #6 + .add(docForSoftClassification(indexName, "green", true, 0.0)) // #7 + .add(docForSoftClassification(indexName, "green", true, 0.0)) // #8 + .add(docForSoftClassification(indexName, "green", true, 1.0)); // #9 highLevelClient().bulk(bulk, RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); @@ -1787,6 +1789,85 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9)); } + public void testEvaluateDataFrame_Classification() throws IOException { + String indexName = "evaluate-classification-test-index"; + createIndex(indexName, mappingForClassification()); + BulkRequest regressionBulk = new BulkRequest() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(docForClassification(indexName, "cat", "cat")) + .add(docForClassification(indexName, "cat", "cat")) + .add(docForClassification(indexName, "cat", "cat")) + .add(docForClassification(indexName, "cat", "dog")) + .add(docForClassification(indexName, "cat", "fish")) + .add(docForClassification(indexName, "dog", "cat")) + .add(docForClassification(indexName, "dog", "dog")) + .add(docForClassification(indexName, "dog", "dog")) + .add(docForClassification(indexName, "dog", "dog")) + .add(docForClassification(indexName, "horse", "cat")); + highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + + { // No size provided for MulticlassConfusionMatrixMetric, default used instead + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + null, + new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric())); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + MulticlassConfusionMatrixMetric.Result mcmResult = + evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME); + assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); + Map> expectedConfusionMatrix = new HashMap<>(); + expectedConfusionMatrix.put("cat", new HashMap<>()); + expectedConfusionMatrix.get("cat").put("cat", 3L); + expectedConfusionMatrix.get("cat").put("dog", 1L); + expectedConfusionMatrix.get("cat").put("horse", 0L); + expectedConfusionMatrix.get("cat").put("_other_", 1L); + expectedConfusionMatrix.put("dog", new HashMap<>()); + expectedConfusionMatrix.get("dog").put("cat", 1L); + expectedConfusionMatrix.get("dog").put("dog", 3L); + expectedConfusionMatrix.get("dog").put("horse", 0L); + expectedConfusionMatrix.put("horse", new HashMap<>()); + expectedConfusionMatrix.get("horse").put("cat", 1L); + expectedConfusionMatrix.get("horse").put("dog", 0L); + expectedConfusionMatrix.get("horse").put("horse", 0L); + assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); + assertThat(mcmResult.getOtherClassesCount(), equalTo(0L)); + } + { // Explicit size provided for MulticlassConfusionMatrixMetric metric + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + null, + new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2))); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + MulticlassConfusionMatrixMetric.Result mcmResult = + evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME); + assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); + Map> expectedConfusionMatrix = new HashMap<>(); + expectedConfusionMatrix.put("cat", new HashMap<>()); + expectedConfusionMatrix.get("cat").put("cat", 3L); + expectedConfusionMatrix.get("cat").put("dog", 1L); + expectedConfusionMatrix.get("cat").put("_other_", 1L); + expectedConfusionMatrix.put("dog", new HashMap<>()); + expectedConfusionMatrix.get("dog").put("cat", 1L); + expectedConfusionMatrix.get("dog").put("dog", 3L); + assertThat(mcmResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); + assertThat(mcmResult.getOtherClassesCount(), equalTo(1L)); + } + } + private static XContentBuilder defaultMappingForTest() throws IOException { return XContentFactory.jsonBuilder().startObject() .startObject("properties") @@ -1804,7 +1885,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { private static final String actualField = "label"; private static final String probabilityField = "p"; - private static XContentBuilder mappingForClassification() throws IOException { + private static XContentBuilder mappingForSoftClassification() throws IOException { return XContentFactory.jsonBuilder().startObject() .startObject("properties") .startObject(datasetField) @@ -1820,26 +1901,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .endObject(); } - private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) { + private static IndexRequest docForSoftClassification(String indexName, String dataset, boolean isTrue, double p) { return new IndexRequest() .index(indexName) .source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p); } + private static final String actualClassField = "actual_class"; + private static final String predictedClassField = "predicted_class"; + + private static XContentBuilder mappingForClassification() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(actualClassField) + .field("type", "keyword") + .endObject() + .startObject(predictedClassField) + .field("type", "keyword") + .endObject() + .endObject() + .endObject(); + } + + private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) { + return new IndexRequest() + .index(indexName) + .source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass); + } + private static final String actualRegression = "regression_actual"; private static final String probabilityRegression = "regression_prob"; private static XContentBuilder mappingForRegression() throws IOException { return XContentFactory.jsonBuilder().startObject() .startObject("properties") - .startObject(actualRegression) - .field("type", "double") + .startObject(actualRegression) + .field("type", "double") + .endObject() + .startObject(probabilityRegression) + .field("type", "double") + .endObject() .endObject() - .startObject(probabilityRegression) - .field("type", "double") - .endObject() - .endObject() - .endObject(); + .endObject(); } private static IndexRequest docForRegression(String indexName, double act, double p) { @@ -1854,11 +1957,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { public void testEstimateMemoryUsage() throws IOException { String indexName = "estimate-test-index"; - createIndex(indexName, mappingForClassification()); + createIndex(indexName, mappingForSoftClassification()); BulkRequest bulk1 = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < 10; ++i) { - bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); + bulk1.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); } highLevelClient().bulk(bulk1, RequestOptions.DEFAULT); @@ -1884,7 +1987,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { BulkRequest bulk2 = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 10; i < 100; ++i) { - bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); + bulk2.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true))); } highLevelClient().bulk(bulk2, RequestOptions.DEFAULT); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 31adbff18a9..59df6ea93a0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -57,7 +57,9 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction; import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -684,7 +686,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(45, namedXContents.size()); + assertEquals(48, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -724,22 +726,24 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Classification.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); - assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); - assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME)); - assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); + assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME)); + assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, + MulticlassConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME, RSquaredMetric.NAME)); - assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, + MulticlassConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME, RSquaredMetric.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java new file mode 100644 index 00000000000..a72b483518c --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.function.Predicate; + +public class ClassificationTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static Classification createRandom() { + return new Classification( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomBoolean() ? null : Arrays.asList(new MulticlassConfusionMatrixMetric())); + } + + @Override + protected Classification createTestInstance() { + return createRandom(); + } + + @Override + protected Classification doParseInstance(XContentParser parser) throws IOException { + return Classification.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java new file mode 100644 index 00000000000..800a2cf7b98 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected MulticlassConfusionMatrixMetric.Result createTestInstance() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + Map> confusionMatrix = new TreeMap<>(); + for (int i = 0; i < numClasses; i++) { + Map row = new TreeMap<>(); + confusionMatrix.put(classNames.get(i), row); + for (int j = 0; j < numClasses; j++) { + if (randomBoolean()) { + row.put(classNames.get(i), randomNonNegativeLong()); + } + } + } + long otherClassesCount = randomNonNegativeLong(); + return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount); + } + + @Override + protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException { + return MulticlassConfusionMatrixMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java new file mode 100644 index 00000000000..f4de12796f0 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected MulticlassConfusionMatrixMetric createTestInstance() { + Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null; + return new MulticlassConfusionMatrixMetric(size); + } + + @Override + protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException { + return MulticlassConfusionMatrixMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 590914fd93d..03b8e2bc49f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -139,6 +139,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; @@ -469,6 +471,14 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new), new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new), // ML - Data frame evaluation + new NamedWriteableRegistry.Entry( + Evaluation.class, + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification.NAME.getPreferredName(), + org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification::new), + new NamedWriteableRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME.getPreferredName(), + MulticlassConfusionMatrix::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(), + MulticlassConfusionMatrix.Result::new), new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new), new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index a2aa8e74918..8036c5ab895 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -8,7 +8,10 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric; @@ -32,6 +35,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluations namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent)); // Soft classification metrics @@ -41,6 +45,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, ConfusionMatrix::fromXContent)); + // Classification metrics + namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME, + MulticlassConfusionMatrix::fromXContent)); + // Regression metrics namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent)); @@ -54,6 +62,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluations namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(), + Classification::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new)); // Evaluation Metrics @@ -65,6 +75,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider Recall::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, + MulticlassConfusionMatrix.NAME.getPreferredName(), + MulticlassConfusionMatrix::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError::new)); @@ -79,6 +92,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider ScoreByThresholdResult::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + MulticlassConfusionMatrix.NAME.getPreferredName(), + MulticlassConfusionMatrix.Result::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError.Result::new)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java new file mode 100644 index 00000000000..b4a9e6f09a2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of classification results. + */ +public class Classification implements Evaluation { + + public static final ParseField NAME = new ParseField("classification"); + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS); + } + + public static Classification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field containing the actual value + * The value of this field is assumed to be numeric + */ + private final String actualField; + + /** + * The field containing the predicted value + * The value of this field is assumed to be numeric + */ + private final String predictedField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public Classification(String actualField, String predictedField, @Nullable List metrics) { + this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); + this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD); + this.metrics = initMetrics(metrics); + } + + public Classification(StreamInput in) throws IOException { + this.actualField = in.readString(); + this.predictedField = in.readString(); + this.metrics = in.readNamedWriteableList(ClassificationMetric.class); + } + + private static List initMetrics(@Nullable List parsedMetrics) { + List metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics); + if (metrics.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); + } + Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName)); + return metrics; + } + + private static List defaultMetrics() { + return Arrays.asList(new MulticlassConfusionMatrix()); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public List getMetrics() { + return metrics; + } + + @Override + public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) { + ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder"); + SearchSourceBuilder searchSourceBuilder = + newSearchSourceBuilder(Arrays.asList(actualField, predictedField), userProvidedQueryBuilder); + for (ClassificationMetric metric : metrics) { + List aggs = metric.aggs(actualField, predictedField); + aggs.forEach(searchSourceBuilder::aggregation); + } + return searchSourceBuilder; + } + + @Override + public void process(SearchResponse searchResponse) { + ExceptionsHelper.requireNonNull(searchResponse, "searchResponse"); + if (searchResponse.getHits().getTotalHits().value == 0) { + throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField); + } + for (ClassificationMetric metric : metrics) { + metric.process(searchResponse.getAggregations()); + } + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualField); + out.writeString(predictedField); + out.writeNamedWriteableList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); + + builder.startObject(METRICS.getPreferredName()); + for (ClassificationMetric metric : metrics) { + builder.field(metric.getWriteableName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Classification that = (Classification) o; + return Objects.equals(that.actualField, this.actualField) + && Objects.equals(that.predictedField, this.predictedField) + && Objects.equals(that.metrics, this.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedField, metrics); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java new file mode 100644 index 00000000000..220942a4838 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; + +import java.util.List; + +public interface ClassificationMetric extends EvaluationMetric { + + /** + * Builds the aggregation that collect required data to compute the metric + * @param actualField the field that stores the actual value + * @param predictedField the field that stores the predicted value + * @return the aggregations required to compute the metric + */ + List aggs(String actualField, String predictedField); + + /** + * Processes given aggregations as a step towards computing result + * @param aggs aggregations from {@link SearchResponse} + */ + void process(Aggregations aggs); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java new file mode 100644 index 00000000000..a8b24a34447 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java @@ -0,0 +1,277 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.BucketOrder; +import org.elasticsearch.search.aggregations.bucket.filter.Filters; +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.TreeMap; +import java.util.stream.Collectors; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * {@link MulticlassConfusionMatrix} is a metric that answers the question: + * "How many examples belonging to class X were classified as Y by the classifier?" + * for all the possible class pairs {X, Y}. + */ +public class MulticlassConfusionMatrix implements ClassificationMetric { + + public static final ParseField NAME = new ParseField("multiclass_confusion_matrix"); + + public static final ParseField SIZE = new ParseField("size"); + + private static final ConstructingObjectParser PARSER = createParser(); + + private static ConstructingObjectParser createParser() { + ConstructingObjectParser parser = + new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0])); + parser.declareInt(optionalConstructorArg(), SIZE); + return parser; + } + + public static MulticlassConfusionMatrix fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class"; + private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class"; + private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class"; + private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class"; + private static final String OTHER_BUCKET_KEY = "_other_"; + private static final int DEFAULT_SIZE = 10; + private static final int MAX_SIZE = 1000; + + private final int size; + private List topActualClassNames; + private Result result; + + public MulticlassConfusionMatrix() { + this((Integer) null); + } + + public MulticlassConfusionMatrix(@Nullable Integer size) { + if (size != null && (size <= 0 || size > MAX_SIZE)) { + throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE); + } + this.size = size != null ? size : DEFAULT_SIZE; + } + + public MulticlassConfusionMatrix(StreamInput in) throws IOException { + this.size = in.readVInt(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + public int getSize() { + return size; + } + + @Override + public final List aggs(String actualField, String predictedField) { + if (topActualClassNames == null) { // This is step 1 + return Arrays.asList( + AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) + .field(actualField) + .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) + .size(size)); + } + if (result == null) { // This is step 2 + KeyedFilter[] keyedFilters = + topActualClassNames.stream() + .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className))) + .toArray(KeyedFilter[]::new); + return Arrays.asList( + AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS) + .field(actualField), + AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) + .field(actualField) + .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))) + .size(size) + .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters) + .otherBucket(true) + .otherBucketKey(OTHER_BUCKET_KEY))); + } + return Collections.emptyList(); + } + + @Override + public void process(Aggregations aggs) { + if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) { + Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS); + topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList()); + } + if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) { + Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS); + Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS); + Map> counts = new TreeMap<>(); + for (Terms.Bucket bucket : termsAgg.getBuckets()) { + String actualClass = bucket.getKeyAsString(); + Map subCounts = new TreeMap<>(); + counts.put(actualClass, subCounts); + Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS); + for (Filters.Bucket subBucket : subAgg.getBuckets()) { + String predictedClass = subBucket.getKeyAsString(); + Long docCount = subBucket.getDocCount(); + if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) { + subCounts.put(predictedClass, docCount); + } + } + } + result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size); + } + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(size); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(SIZE.getPreferredName(), size); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o; + return Objects.equals(this.size, that.size); + } + + @Override + public int hashCode() { + return Objects.hash(size); + } + + public static class Result implements EvaluationMetricResult { + + private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix"); + private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "multiclass_confusion_matrix_result", true, a -> new Result((Map>) a[0], (long) a[1])); + + static { + PARSER.declareObject( + constructorArg(), + (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)), + CONFUSION_MATRIX); + PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT); + } + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + // Immutable + private final Map> confusionMatrix; + private final long otherClassesCount; + + public Result(Map> confusionMatrix, long otherClassesCount) { + this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix)); + this.otherClassesCount = otherClassesCount; + } + + public Result(StreamInput in) throws IOException { + this.confusionMatrix = Collections.unmodifiableMap( + in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong))); + this.otherClassesCount = in.readLong(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + public Map> getConfusionMatrix() { + return confusionMatrix; + } + + public long getOtherClassesCount() { + return otherClassesCount; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap( + confusionMatrix, + StreamOutput::writeString, + (out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong)); + out.writeLong(otherClassesCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix); + builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(this.confusionMatrix, that.confusionMatrix) + && this.otherClassesCount == that.otherClassesCount; + } + + @Override + public int hashCode() { + return Objects.hash(confusionMatrix, otherClassesCount); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java index 3d095e995a6..c9eb0ae437e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; @@ -29,6 +30,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin String evaluationName = randomAlphaOfLength(10); List metrics = Arrays.asList( + MulticlassConfusionMatrixResultTests.createRandom(), new MeanSquaredError.Result(randomDouble()), new RSquared.Result(randomDouble())); int numMetrics = randomIntBetween(0, metrics.size()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java new file mode 100644 index 00000000000..96cbdf843db --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -0,0 +1,222 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ClassificationTests extends AbstractSerializingTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static Classification createRandom() { + return new Classification( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomBoolean() ? null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom())); + } + + @Override + protected Classification doParseInstance(XContentParser parser) throws IOException { + return Classification.fromXContent(parser); + } + + @Override + protected Classification createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Classification::new; + } + + public void testConstructor_GivenEmptyMetrics() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Classification("foo", "bar", Collections.emptyList())); + assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics")); + } + + public void testBuildSearch() { + QueryBuilder userProvidedQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value")); + QueryBuilder expectedSearchQuery = + QueryBuilders.boolQuery() + .filter(QueryBuilders.existsQuery("act")) + .filter(QueryBuilders.existsQuery("pred")) + .filter(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("field_A", "some-value")) + .filter(QueryBuilders.termQuery("field_B", "some-other-value"))); + + Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix())); + + SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery); + assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery)); + assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0)); + } + + public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() { + ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2); + ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3); + ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4); + ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5); + + Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4)); + assertThat(metric1.getResult(), isEmpty()); + assertThat(metric2.getResult(), isEmpty()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isEmpty()); + assertThat(metric2.getResult(), isEmpty()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isEmpty()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isEmpty()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isPresent()); + assertThat(metric4.getResult(), isEmpty()); + assertThat(evaluation.hasAllResults(), is(false)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isPresent()); + assertThat(metric4.getResult(), isPresent()); + assertThat(evaluation.hasAllResults(), is(true)); + + evaluation.process(mockSearchResponseWithNonZeroTotalHits()); + assertThat(metric1.getResult(), isPresent()); + assertThat(metric2.getResult(), isPresent()); + assertThat(metric3.getResult(), isPresent()); + assertThat(metric4.getResult(), isPresent()); + assertThat(evaluation.hasAllResults(), is(true)); + } + + private static SearchResponse mockSearchResponseWithNonZeroTotalHits() { + SearchResponse searchResponse = mock(SearchResponse.class); + SearchHits hits = new SearchHits(SearchHits.EMPTY, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0); + when(searchResponse.getHits()).thenReturn(hits); + return searchResponse; + } + + /** + * Metric which iterates through its steps in {@link #process} method. + * Number of steps is configurable. + * Upon reaching the last step, the result is produced. + */ + private static class FakeClassificationMetric implements ClassificationMetric { + + private final String name; + private final int numSteps; + private int currentStepIndex; + private EvaluationMetricResult result; + + FakeClassificationMetric(String name, int numSteps) { + this.name = name; + this.numSteps = numSteps; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getWriteableName() { + return name; + } + + @Override + public List aggs(String actualField, String predictedField) { + return Collections.emptyList(); + } + + @Override + public void process(Aggregations aggs) { + if (result != null) { + return; + } + currentStepIndex++; + if (currentStepIndex == numSteps) { + // This is the last step, time to write evaluation result + result = mock(EvaluationMetricResult.class); + } + } + + @Override + public Optional getResult() { + return Optional.ofNullable(result); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) { + return builder; + } + + @Override + public void writeTo(StreamOutput out) { + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java new file mode 100644 index 00000000000..24b13d372d5 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase { + + public static MulticlassConfusionMatrix.Result createRandom() { + int numClasses = randomIntBetween(2, 100); + List classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList()); + Map> confusionMatrix = new TreeMap<>(); + for (int i = 0; i < numClasses; i++) { + Map row = new TreeMap<>(); + confusionMatrix.put(classNames.get(i), row); + for (int j = 0; j < numClasses; j++) { + if (randomBoolean()) { + row.put(classNames.get(i), randomNonNegativeLong()); + } + } + } + long otherClassesCount = randomNonNegativeLong(); + return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount); + } + + @Override + protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException { + return MulticlassConfusionMatrix.Result.fromXContent(parser); + } + + @Override + protected MulticlassConfusionMatrix.Result createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return MulticlassConfusionMatrix.Result::new; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java new file mode 100644 index 00000000000..ff788460b49 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java @@ -0,0 +1,205 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filters; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase { + + @Override + protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException { + return MulticlassConfusionMatrix.fromXContent(parser); + } + + @Override + protected MulticlassConfusionMatrix createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return MulticlassConfusionMatrix::new; + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + public static MulticlassConfusionMatrix createRandom() { + Integer size = randomBoolean() ? null : randomIntBetween(1, 1000); + return new MulticlassConfusionMatrix(size); + } + + public void testConstructor_SizeValidationFailures() { + { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1)); + assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); + } + { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0)); + assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); + } + { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001)); + assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]")); + } + } + + public void testAggs() { + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(); + List aggs = confusionMatrix.aggs("act", "pred"); + assertThat(aggs, is(not(empty()))); + assertThat(confusionMatrix.getResult(), equalTo(Optional.empty())); + } + + public void testEvaluate() { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms( + "multiclass_confusion_matrix_step_1_by_actual_class", + Arrays.asList( + mockTermsBucket("dog", new Aggregations(Collections.emptyList())), + mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), + 0L), + mockTerms( + "multiclass_confusion_matrix_step_2_by_actual_class", + Arrays.asList( + mockTermsBucket( + "dog", + new Aggregations(Arrays.asList(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + Arrays.asList( + mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockTermsBucket( + "cat", + new Aggregations(Arrays.asList(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + Arrays.asList( + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))), + 0L), + mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L))); + + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + confusionMatrix.process(aggs); + + assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); + MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Map> expectedConfusionMatrix = new HashMap<>(); + expectedConfusionMatrix.put("dog", new HashMap<>()); + expectedConfusionMatrix.get("dog").put("cat", 10L); + expectedConfusionMatrix.get("dog").put("dog", 20L); + expectedConfusionMatrix.put("cat", new HashMap<>()); + expectedConfusionMatrix.get("cat").put("cat", 30L); + expectedConfusionMatrix.get("cat").put("dog", 40L); + assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); + assertThat(result.getOtherClassesCount(), equalTo(0L)); + } + + public void testEvaluate_OtherClassesCountGreaterThanZero() { + Aggregations aggs = new Aggregations(Arrays.asList( + mockTerms( + "multiclass_confusion_matrix_step_1_by_actual_class", + Arrays.asList( + mockTermsBucket("dog", new Aggregations(Collections.emptyList())), + mockTermsBucket("cat", new Aggregations(Collections.emptyList()))), + 100L), + mockTerms( + "multiclass_confusion_matrix_step_2_by_actual_class", + Arrays.asList( + mockTermsBucket( + "dog", + new Aggregations(Arrays.asList(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + Arrays.asList( + mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))), + mockTermsBucket( + "cat", + new Aggregations(Arrays.asList(mockFilters( + "multiclass_confusion_matrix_step_2_by_predicted_class", + Arrays.asList( + mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))), + 100L), + mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L))); + + MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2); + confusionMatrix.process(aggs); + + assertThat(confusionMatrix.aggs("act", "pred"), is(empty())); + MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get(); + assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix")); + Map> expectedConfusionMatrix = new HashMap<>(); + expectedConfusionMatrix.put("dog", new HashMap<>()); + expectedConfusionMatrix.get("dog").put("cat", 10L); + expectedConfusionMatrix.get("dog").put("dog", 20L); + expectedConfusionMatrix.put("cat", new HashMap<>()); + expectedConfusionMatrix.get("cat").put("cat", 30L); + expectedConfusionMatrix.get("cat").put("dog", 40L); + expectedConfusionMatrix.get("cat").put("_other_", 15L); + assertThat(result.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); + assertThat(result.getOtherClassesCount(), equalTo(3L)); + } + + private static Terms mockTerms(String name, List buckets, long sumOfOtherDocCounts) { + Terms aggregation = mock(Terms.class); + when(aggregation.getName()).thenReturn(name); + doReturn(buckets).when(aggregation).getBuckets(); + when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts); + return aggregation; + } + + private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) { + Terms.Bucket bucket = mock(Terms.Bucket.class); + when(bucket.getKeyAsString()).thenReturn(actualClass); + when(bucket.getAggregations()).thenReturn(subAggs); + return bucket; + } + + private static Filters mockFilters(String name, List buckets) { + Filters aggregation = mock(Filters.class); + when(aggregation.getName()).thenReturn(name); + doReturn(buckets).when(aggregation).getBuckets(); + return aggregation; + } + + private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) { + Filters.Bucket bucket = mock(Filters.Bucket.class); + when(bucket.getKeyAsString()).thenReturn(predictedClass); + when(bucket.getDocCount()).thenReturn(docCount); + return bucket; + } + + private static Cardinality mockCardinality(String name, long value) { + Cardinality aggregation = mock(Cardinality.class); + when(aggregation.getName()).thenReturn(name); + when(aggregation.getValue()).thenReturn(value); + return aggregation; + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 6dfa5798c4b..b0fbfc5cd37 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -100,6 +100,9 @@ integTest.runner { 'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds', + 'ml/evaluate_data_frame/Test classification given evaluation with empty metrics', + 'ml/evaluate_data_frame/Test classification given missing actual_field', + 'ml/evaluate_data_frame/Test classification given missing predicted_field', 'ml/evaluate_data_frame/Test regression given evaluation with empty metrics', 'ml/evaluate_data_frame/Test regression given missing actual_field', 'ml/evaluate_data_frame/Test regression given missing predicted_field', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java new file mode 100644 index 00000000000..ba70828f5c1 --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java @@ -0,0 +1,200 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.junit.After; +import org.junit.Before; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; + +public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase { + + private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index"; + + private static final String ACTUAL_CLASS_FIELD = "actual_class_field"; + private static final String PREDICTED_CLASS_FIELD = "predicted_class_field"; + + @Before + public void setup() { + indexAnimalsData(ANIMALS_DATA_INDEX); + } + + @After + public void cleanup() { + cleanUp(); + } + + public void testEvaluate_MulticlassClassification_DefaultMetrics() { + EvaluateDataFrameAction.Request evaluateDataFrameRequest = + new EvaluateDataFrameAction.Request() + .setIndices(Arrays.asList(ANIMALS_DATA_INDEX)) + .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null)); + + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + Map> expectedConfusionMatrix = new HashMap<>(); + expectedConfusionMatrix.put("ant", new HashMap<>()); + expectedConfusionMatrix.get("ant").put("ant", 1L); + expectedConfusionMatrix.get("ant").put("cat", 4L); + expectedConfusionMatrix.get("ant").put("dog", 3L); + expectedConfusionMatrix.get("ant").put("fox", 2L); + expectedConfusionMatrix.get("ant").put("mouse", 5L); + expectedConfusionMatrix.put("cat", new HashMap<>()); + expectedConfusionMatrix.get("cat").put("ant", 3L); + expectedConfusionMatrix.get("cat").put("cat", 1L); + expectedConfusionMatrix.get("cat").put("dog", 5L); + expectedConfusionMatrix.get("cat").put("fox", 4L); + expectedConfusionMatrix.get("cat").put("mouse", 2L); + expectedConfusionMatrix.put("dog", new HashMap<>()); + expectedConfusionMatrix.get("dog").put("ant", 4L); + expectedConfusionMatrix.get("dog").put("cat", 2L); + expectedConfusionMatrix.get("dog").put("dog", 1L); + expectedConfusionMatrix.get("dog").put("fox", 5L); + expectedConfusionMatrix.get("dog").put("mouse", 3L); + expectedConfusionMatrix.put("fox", new HashMap<>()); + expectedConfusionMatrix.get("fox").put("ant", 5L); + expectedConfusionMatrix.get("fox").put("cat", 3L); + expectedConfusionMatrix.get("fox").put("dog", 2L); + expectedConfusionMatrix.get("fox").put("fox", 1L); + expectedConfusionMatrix.get("fox").put("mouse", 4L); + expectedConfusionMatrix.put("mouse", new HashMap<>()); + expectedConfusionMatrix.get("mouse").put("ant", 2L); + expectedConfusionMatrix.get("mouse").put("cat", 5L); + expectedConfusionMatrix.get("mouse").put("dog", 4L); + expectedConfusionMatrix.get("mouse").put("fox", 3L); + expectedConfusionMatrix.get("mouse").put("mouse", 1L); + assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); + assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + } + + public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() { + EvaluateDataFrameAction.Request evaluateDataFrameRequest = + new EvaluateDataFrameAction.Request() + .setIndices(Arrays.asList(ANIMALS_DATA_INDEX)) + .setEvaluation( + new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix()))); + + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + Map> expectedConfusionMatrix = new HashMap<>(); + expectedConfusionMatrix.put("ant", new HashMap<>()); + expectedConfusionMatrix.get("ant").put("ant", 1L); + expectedConfusionMatrix.get("ant").put("cat", 4L); + expectedConfusionMatrix.get("ant").put("dog", 3L); + expectedConfusionMatrix.get("ant").put("fox", 2L); + expectedConfusionMatrix.get("ant").put("mouse", 5L); + expectedConfusionMatrix.put("cat", new HashMap<>()); + expectedConfusionMatrix.get("cat").put("ant", 3L); + expectedConfusionMatrix.get("cat").put("cat", 1L); + expectedConfusionMatrix.get("cat").put("dog", 5L); + expectedConfusionMatrix.get("cat").put("fox", 4L); + expectedConfusionMatrix.get("cat").put("mouse", 2L); + expectedConfusionMatrix.put("dog", new HashMap<>()); + expectedConfusionMatrix.get("dog").put("ant", 4L); + expectedConfusionMatrix.get("dog").put("cat", 2L); + expectedConfusionMatrix.get("dog").put("dog", 1L); + expectedConfusionMatrix.get("dog").put("fox", 5L); + expectedConfusionMatrix.get("dog").put("mouse", 3L); + expectedConfusionMatrix.put("fox", new HashMap<>()); + expectedConfusionMatrix.get("fox").put("ant", 5L); + expectedConfusionMatrix.get("fox").put("cat", 3L); + expectedConfusionMatrix.get("fox").put("dog", 2L); + expectedConfusionMatrix.get("fox").put("fox", 1L); + expectedConfusionMatrix.get("fox").put("mouse", 4L); + expectedConfusionMatrix.put("mouse", new HashMap<>()); + expectedConfusionMatrix.get("mouse").put("ant", 2L); + expectedConfusionMatrix.get("mouse").put("cat", 5L); + expectedConfusionMatrix.get("mouse").put("dog", 4L); + expectedConfusionMatrix.get("mouse").put("fox", 3L); + expectedConfusionMatrix.get("mouse").put("mouse", 1L); + assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); + assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L)); + } + + public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() { + EvaluateDataFrameAction.Request evaluateDataFrameRequest = + new EvaluateDataFrameAction.Request() + .setIndices(Arrays.asList(ANIMALS_DATA_INDEX)) + .setEvaluation( + new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, Arrays.asList(new MulticlassConfusionMatrix(3)))); + + EvaluateDataFrameAction.Response evaluateDataFrameResponse = + client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet(); + + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName())); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + MulticlassConfusionMatrix.Result confusionMatrixResult = + (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0); + assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName())); + Map> expectedConfusionMatrix = new HashMap<>(); + expectedConfusionMatrix.put("ant", new HashMap<>()); + expectedConfusionMatrix.get("ant").put("ant", 1L); + expectedConfusionMatrix.get("ant").put("cat", 4L); + expectedConfusionMatrix.get("ant").put("dog", 3L); + expectedConfusionMatrix.get("ant").put("_other_", 7L); + expectedConfusionMatrix.put("cat", new HashMap<>()); + expectedConfusionMatrix.get("cat").put("ant", 3L); + expectedConfusionMatrix.get("cat").put("cat", 1L); + expectedConfusionMatrix.get("cat").put("dog", 5L); + expectedConfusionMatrix.get("cat").put("_other_", 6L); + expectedConfusionMatrix.put("dog", new HashMap<>()); + expectedConfusionMatrix.get("dog").put("ant", 4L); + expectedConfusionMatrix.get("dog").put("cat", 2L); + expectedConfusionMatrix.get("dog").put("dog", 1L); + expectedConfusionMatrix.get("dog").put("_other_", 8L); + assertThat(confusionMatrixResult.getConfusionMatrix(), equalTo(expectedConfusionMatrix)); + assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L)); + } + + private static void indexAnimalsData(String indexName) { + client().admin().indices().prepareCreate(indexName) + .addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword") + .get(); + + List classNames = Arrays.asList("dog", "cat", "mouse", "ant", "fox"); + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < classNames.size(); i++) { + for (int j = 0; j < classNames.size(); j++) { + for (int k = 0; k < j + 1; k++) { + bulkRequestBuilder.add( + new IndexRequest(indexName) + .source( + ACTUAL_CLASS_FIELD, classNames.get(i), + PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size()))); + } + } + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 7459e695901..1bcde11f2fb 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -11,6 +11,8 @@ setup: "outlier_score": 0.0, "regression_field_act": 10.9, "regression_field_pred": 10.9, + "classification_field_act": "dog", + "classification_field_pred": "dog", "all_true_field": true, "all_false_field": false } @@ -26,6 +28,8 @@ setup: "outlier_score": 0.2, "regression_field_act": 12.0, "regression_field_pred": 9.9, + "classification_field_act": "cat", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -41,6 +45,8 @@ setup: "outlier_score": 0.3, "regression_field_act": 20.9, "regression_field_pred": 5.9, + "classification_field_act": "mouse", + "classification_field_pred": "mouse", "all_true_field": true, "all_false_field": false } @@ -56,6 +62,8 @@ setup: "outlier_score": 0.3, "regression_field_act": 11.9, "regression_field_pred": 11.9, + "classification_field_act": "dog", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -71,6 +79,8 @@ setup: "outlier_score": 0.4, "regression_field_act": 42.9, "regression_field_pred": 42.9, + "classification_field_act": "cat", + "classification_field_pred": "dog", "all_true_field": true, "all_false_field": false } @@ -86,6 +96,8 @@ setup: "outlier_score": 0.5, "regression_field_act": 0.42, "regression_field_pred": 0.42, + "classification_field_act": "dog", + "classification_field_pred": "dog", "all_true_field": true, "all_false_field": false } @@ -101,6 +113,8 @@ setup: "outlier_score": 0.9, "regression_field_act": 1.1235813, "regression_field_pred": 1.12358, + "classification_field_act": "cat", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -116,6 +130,8 @@ setup: "outlier_score": 0.95, "regression_field_act": -5.20, "regression_field_pred": -5.1, + "classification_field_act": "mouse", + "classification_field_pred": "cat", "all_true_field": true, "all_false_field": false } @@ -569,6 +585,108 @@ setup: } } } + +--- +"Test classification given evaluation with empty metrics": + - do: + catch: /\[classification\] must have one or more metrics/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { } + } + } + } +--- +"Test classification multiclass_confusion_matrix": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { "multiclass_confusion_matrix": {} } + } + } + } + + - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } + - match: { classification.multiclass_confusion_matrix._other_: 0 } +--- +"Test classification multiclass_confusion_matrix with explicit size": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword", + "metrics": { "multiclass_confusion_matrix": { "size": 2 } } + } + } + } + + - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } } + - match: { classification.multiclass_confusion_matrix._other_: 1 } +--- +"Test classification with null metrics": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "classification_field_pred.keyword" + } + } + } + + - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } } + - match: { classification.multiclass_confusion_matrix._other_: 0 } +--- +"Test classification given missing actual_field": + - do: + catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "missing", + "predicted_field": "classification_field_pred.keyword" + } + } + } + +--- +"Test classification given missing predicted_field": + - do: + catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "classification": { + "actual_field": "classification_field_act.keyword", + "predicted_field": "missing" + } + } + } + --- "Test regression given evaluation with empty metrics": - do: