diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 1110de4d644..e4534c5603b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -19,13 +19,13 @@ package org.elasticsearch.client.ml.dataframe.evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; @@ -63,34 +63,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluation metrics new NamedXContentRegistry.Entry( EvaluationMetric.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)), - AucRocMetric::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)), - PrecisionMetric::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)), - RecallMetric::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)), ConfusionMatrixMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, + new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)), + AucRocMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)), AccuracyMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, - new ParseField(registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)), - org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent), + new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)), + PrecisionMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, - new ParseField(registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)), - org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent), + new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)), + RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)), @@ -114,34 +122,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluation metrics results new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)), - AucRocMetric.Result::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)), - PrecisionMetric.Result::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)), - RecallMetric.Result::fromXContent), + new ParseField( + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)), ConfusionMatrixMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, + new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)), + AucRocMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)), AccuracyMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField(registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)), - org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent), + new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)), + PrecisionMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField(registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)), - org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent), + new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)), + RecallMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)), diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java new file mode 100644 index 00000000000..79cb13718a5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java @@ -0,0 +1,264 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.classification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Area under the curve (AUC) of the receiver operating characteristic (ROC). + * The ROC curve is a plot of the TPR (true positive rate) against + * the FPR (false positive rate) over a varying threshold. + */ +public class AucRocMetric implements EvaluationMetric { + + public static final String NAME = "auc_roc"; + + public static final ParseField CLASS_NAME = new ParseField("class_name"); + public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1])); + + static { + PARSER.declareString(constructorArg(), CLASS_NAME); + PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE); + } + + public static AucRocMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static AucRocMetric forClass(String className) { + return new AucRocMetric(className, false); + } + + public static AucRocMetric forClassWithCurve(String className) { + return new AucRocMetric(className, true); + } + + private final String className; + private final Boolean includeCurve; + + public AucRocMetric(String className, Boolean includeCurve) { + this.className = Objects.requireNonNull(className); + this.includeCurve = includeCurve; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + if (includeCurve != null) { + builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve); + } + builder.endObject(); + return builder; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRocMetric that = (AucRocMetric) o; + return Objects.equals(className, that.className) + && Objects.equals(includeCurve, that.includeCurve); + } + + @Override + public int hashCode() { + return Objects.hash(className, includeCurve); + } + + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField SCORE = new ParseField("score"); + private static final ParseField DOC_COUNT = new ParseField("doc_count"); + private static final ParseField CURVE = new ParseField("curve"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List) args[2])); + + static { + PARSER.declareDouble(constructorArg(), SCORE); + PARSER.declareLong(constructorArg(), DOC_COUNT); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE); + } + + private final double score; + private final long docCount; + private final List curve; + + public Result(double score, long docCount, @Nullable List curve) { + this.score = score; + this.docCount = docCount; + this.curve = curve; + } + + @Override + public String getMetricName() { + return NAME; + } + + public double getScore() { + return score; + } + + public long getDocCount() { + return docCount; + } + + public List getCurve() { + return curve == null ? null : Collections.unmodifiableList(curve); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(SCORE.getPreferredName(), score); + builder.field(DOC_COUNT.getPreferredName(), docCount); + if (curve != null && curve.isEmpty() == false) { + builder.field(CURVE.getPreferredName(), curve); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return score == that.score + && docCount == that.docCount + && Objects.equals(curve, that.curve); + } + + @Override + public int hashCode() { + return Objects.hash(score, docCount, curve); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static final class AucRocPoint implements ToXContentObject { + + public static AucRocPoint fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField TPR = new ParseField("tpr"); + private static final ParseField FPR = new ParseField("fpr"); + private static final ParseField THRESHOLD = new ParseField("threshold"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "auc_roc_point", + true, + args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2])); + + static { + PARSER.declareDouble(constructorArg(), TPR); + PARSER.declareDouble(constructorArg(), FPR); + PARSER.declareDouble(constructorArg(), THRESHOLD); + } + + private final double tpr; + private final double fpr; + private final double threshold; + + public AucRocPoint(double tpr, double fpr, double threshold) { + this.tpr = tpr; + this.fpr = fpr; + this.threshold = threshold; + } + + public double getTruePositiveRate() { + return tpr; + } + + public double getFalsePositiveRate() { + return fpr; + } + + public double getThreshold() { + return threshold; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(TPR.getPreferredName(), tpr) + .field(FPR.getPreferredName(), fpr) + .field(THRESHOLD.getPreferredName(), threshold) + .endObject(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRocPoint that = (AucRocPoint) o; + return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold; + } + + @Override + public int hashCode() { + return Objects.hash(tpr, fpr, threshold); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java index f6407822898..3c7803da94b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java @@ -45,15 +45,20 @@ public class Classification implements Evaluation { private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); + private static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field"); + private static final ParseField METRICS = new ParseField("metrics"); @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME, true, a -> new Classification((String) a[0], (String) a[1], (List) a[2])); + NAME, + true, + a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List) a[3])); static { PARSER.declareString(constructorArg(), ACTUAL_FIELD); - PARSER.declareString(constructorArg(), PREDICTED_FIELD); + PARSER.declareString(optionalConstructorArg(), PREDICTED_FIELD); + PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_FIELD); PARSER.declareNamedObjects( optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS); } @@ -64,32 +69,44 @@ public class Classification implements Evaluation { /** * The field containing the actual value - * The value of this field is assumed to be numeric */ private final String actualField; /** * The field containing the predicted value - * The value of this field is assumed to be numeric */ private final String predictedField; + /** + * The field containing the array of top classes + */ + private final String topClassesField; + /** * The list of metrics to calculate */ private final List metrics; - public Classification(String actualField, String predictedField) { - this(actualField, predictedField, (List)null); + public Classification(String actualField, + String predictedField, + String topClassesField) { + this(actualField, predictedField, topClassesField, (List)null); } - public Classification(String actualField, String predictedField, EvaluationMetric... metrics) { - this(actualField, predictedField, Arrays.asList(metrics)); + public Classification(String actualField, + String predictedField, + String topClassesField, + EvaluationMetric... metrics) { + this(actualField, predictedField, topClassesField, Arrays.asList(metrics)); } - public Classification(String actualField, String predictedField, @Nullable List metrics) { + public Classification(String actualField, + @Nullable String predictedField, + @Nullable String topClassesField, + @Nullable List metrics) { this.actualField = Objects.requireNonNull(actualField); - this.predictedField = Objects.requireNonNull(predictedField); + this.predictedField = predictedField; + this.topClassesField = topClassesField; if (metrics != null) { metrics.sort(Comparator.comparing(EvaluationMetric::getName)); } @@ -105,8 +122,12 @@ public class Classification implements Evaluation { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ACTUAL_FIELD.getPreferredName(), actualField); - builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); - + if (predictedField != null) { + builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); + } + if (topClassesField != null) { + builder.field(TOP_CLASSES_FIELD.getPreferredName(), topClassesField); + } if (metrics != null) { builder.startObject(METRICS.getPreferredName()); for (EvaluationMetric metric : metrics) { @@ -126,11 +147,12 @@ public class Classification implements Evaluation { Classification that = (Classification) o; return Objects.equals(that.actualField, this.actualField) && Objects.equals(that.predictedField, this.predictedField) + && Objects.equals(that.topClassesField, this.topClassesField) && Objects.equals(that.metrics, this.metrics); } @Override public int hashCode() { - return Objects.hash(actualField, predictedField, metrics); + return Objects.hash(actualField, predictedField, topClassesField, metrics); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java index 959de6a97a8..76d8c514dae 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java @@ -19,21 +19,14 @@ package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; -import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; -import java.util.Collections; -import java.util.List; import java.util.Objects; -import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; /** @@ -49,7 +42,7 @@ public class AucRocMetric implements EvaluationMetric { @SuppressWarnings("unchecked") public static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0])); + new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((Boolean) args[0])); static { PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE); @@ -63,18 +56,20 @@ public class AucRocMetric implements EvaluationMetric { return new AucRocMetric(true); } - private final boolean includeCurve; + private final Boolean includeCurve; public AucRocMetric(Boolean includeCurve) { - this.includeCurve = includeCurve == null ? false : includeCurve; + this.includeCurve = includeCurve; } @Override - public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - return builder - .startObject() - .field(INCLUDE_CURVE.getPreferredName(), includeCurve) - .endObject(); + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (includeCurve != null) { + builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve); + } + builder.endObject(); + return builder; } @Override @@ -94,148 +89,4 @@ public class AucRocMetric implements EvaluationMetric { public int hashCode() { return Objects.hash(includeCurve); } - - public static class Result implements EvaluationMetric.Result { - - public static Result fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); - } - - private static final ParseField SCORE = new ParseField("score"); - private static final ParseField CURVE = new ParseField("curve"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("auc_roc_result", true, args -> new Result((double) args[0], (List) args[1])); - - static { - PARSER.declareDouble(constructorArg(), SCORE); - PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE); - } - - private final double score; - private final List curve; - - public Result(double score, @Nullable List curve) { - this.score = score; - this.curve = curve; - } - - @Override - public String getMetricName() { - return NAME; - } - - public double getScore() { - return score; - } - - public List getCurve() { - return curve == null ? null : Collections.unmodifiableList(curve); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.startObject(); - builder.field(SCORE.getPreferredName(), score); - if (curve != null && curve.isEmpty() == false) { - builder.field(CURVE.getPreferredName(), curve); - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Result that = (Result) o; - return Objects.equals(score, that.score) - && Objects.equals(curve, that.curve); - } - - @Override - public int hashCode() { - return Objects.hash(score, curve); - } - - @Override - public String toString() { - return Strings.toString(this); - } - } - - public static final class AucRocPoint implements ToXContentObject { - - public static AucRocPoint fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); - } - - private static final ParseField TPR = new ParseField("tpr"); - private static final ParseField FPR = new ParseField("fpr"); - private static final ParseField THRESHOLD = new ParseField("threshold"); - - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - "auc_roc_point", - true, - args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2])); - - static { - PARSER.declareDouble(constructorArg(), TPR); - PARSER.declareDouble(constructorArg(), FPR); - PARSER.declareDouble(constructorArg(), THRESHOLD); - } - - private final double tpr; - private final double fpr; - private final double threshold; - - public AucRocPoint(double tpr, double fpr, double threshold) { - this.tpr = tpr; - this.fpr = fpr; - this.threshold = threshold; - } - - public double getTruePositiveRate() { - return tpr; - } - - public double getFalsePositiveRate() { - return fpr; - } - - public double getThreshold() { - return threshold; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder - .startObject() - .field(TPR.getPreferredName(), tpr) - .field(FPR.getPreferredName(), fpr) - .field(THRESHOLD.getPreferredName(), threshold) - .endObject(); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - AucRocPoint that = (AucRocPoint) o; - return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold; - } - - @Override - public int hashCode() { - return Objects.hash(tpr, fpr, threshold); - } - - @Override - public String toString() { - return Strings.toString(this); - } - } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 36608760079..0d8c4d6ff39 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -138,13 +138,13 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; @@ -1774,15 +1774,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { new OutlierDetection( actualField, probabilityField, - PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve())); + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), + ConfusionMatrixMetric.at(0.5), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve())); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4)); - PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME); - assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME)); + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult = + evaluateDataFrameResponse.getMetricByName( + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME); + assertThat( + precisionResult.getMetricName(), + equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)); // Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9) assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9)); // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) @@ -1791,8 +1798,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9)); assertNull(precisionResult.getScoreByThreshold("0.1")); - RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME); - assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME)); + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result recallResult = + evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME); + assertThat( + recallResult.getMetricName(), + equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)); // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9)); // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) @@ -1808,7 +1818,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7 assertNull(confusionMatrixResult.getScoreByThreshold("0.1")); - AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); + AucRocMetric.Result aucRocResult = + evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9)); assertNotNull(aucRocResult.getCurve()); @@ -1920,24 +1931,40 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { createIndex(indexName, mappingForClassification()); BulkRequest regressionBulk = new BulkRequest() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(docForClassification(indexName, "cat", "cat")) - .add(docForClassification(indexName, "cat", "cat")) - .add(docForClassification(indexName, "cat", "cat")) - .add(docForClassification(indexName, "cat", "dog")) - .add(docForClassification(indexName, "cat", "fish")) - .add(docForClassification(indexName, "dog", "cat")) - .add(docForClassification(indexName, "dog", "dog")) - .add(docForClassification(indexName, "dog", "dog")) - .add(docForClassification(indexName, "dog", "dog")) - .add(docForClassification(indexName, "ant", "cat")); + .add(docForClassification(indexName, "cat", "cat", 0.9)) + .add(docForClassification(indexName, "cat", "cat", 0.85)) + .add(docForClassification(indexName, "cat", "cat", 0.95)) + .add(docForClassification(indexName, "cat", "dog", 0.4)) + .add(docForClassification(indexName, "cat", "fish", 0.35)) + .add(docForClassification(indexName, "dog", "cat", 0.5)) + .add(docForClassification(indexName, "dog", "dog", 0.4)) + .add(docForClassification(indexName, "dog", "dog", 0.35)) + .add(docForClassification(indexName, "dog", "dog", 0.6)) + .add(docForClassification(indexName, "ant", "cat", 0.1)); highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + { // AucRoc + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, null, new Classification(actualClassField, null, topClassesField, AucRocMetric.forClassWithCurve("cat"))); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); + assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); + assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9)); + assertThat(aucRocResult.getDocCount(), equalTo(5L)); + assertNotNull(aucRocResult.getCurve()); + } { // Accuracy EvaluateDataFrameRequest evaluateDataFrameRequest = new EvaluateDataFrameRequest( - indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric())); + indexName, null, new Classification(actualClassField, predictedClassField, null, new AccuracyMetric())); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); @@ -1961,65 +1988,47 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { { // Precision EvaluateDataFrameRequest evaluateDataFrameRequest = new EvaluateDataFrameRequest( - indexName, - null, - new Classification( - actualClassField, - predictedClassField, - new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric())); + indexName, null, new Classification(actualClassField, predictedClassField, null, new PrecisionMetric())); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); - org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult = - evaluateDataFrameResponse.getMetricByName( - org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); - assertThat( - precisionResult.getMetricName(), - equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)); + PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME); + assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME)); assertThat( precisionResult.getClasses(), equalTo( Arrays.asList( // 3 out of 5 examples labeled as "cat" were classified correctly - new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6), + new PrecisionMetric.PerClassResult("cat", 0.6), // 3 out of 4 examples labeled as "dog" were classified correctly - new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75)))); + new PrecisionMetric.PerClassResult("dog", 0.75)))); assertThat(precisionResult.getAvgPrecision(), equalTo(0.675)); } { // Recall EvaluateDataFrameRequest evaluateDataFrameRequest = new EvaluateDataFrameRequest( - indexName, - null, - new Classification( - actualClassField, - predictedClassField, - new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric())); + indexName, null, new Classification(actualClassField, predictedClassField, null, new RecallMetric())); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); - org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult = - evaluateDataFrameResponse.getMetricByName( - org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); - assertThat( - recallResult.getMetricName(), - equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)); + RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME); + assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME)); assertThat( recallResult.getClasses(), equalTo( Arrays.asList( // 3 out of 5 examples labeled as "cat" were classified correctly - new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6), + new RecallMetric.PerClassResult("cat", 0.6), // 3 out of 4 examples labeled as "dog" were classified correctly - new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75), + new RecallMetric.PerClassResult("dog", 0.75), // no examples labeled as "ant" were classified correctly - new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0)))); + new RecallMetric.PerClassResult("ant", 0.0)))); assertThat(recallResult.getAvgRecall(), equalTo(0.45)); } { // No size provided for MulticlassConfusionMatrixMetric, default used instead @@ -2027,7 +2036,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { new EvaluateDataFrameRequest( indexName, null, - new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric())); + new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric())); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); @@ -2072,7 +2081,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { new EvaluateDataFrameRequest( indexName, null, - new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2))); + new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric(2))); EvaluateDataFrameResponse evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); @@ -2146,6 +2155,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { private static final String actualClassField = "actual_class"; private static final String predictedClassField = "predicted_class"; + private static final String topClassesField = "top_classes"; private static XContentBuilder mappingForClassification() throws IOException { return XContentFactory.jsonBuilder().startObject() @@ -2156,14 +2166,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .startObject(predictedClassField) .field("type", "keyword") .endObject() + .startObject(topClassesField) + .field("type", "nested") + .endObject() .endObject() .endObject(); } - private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) { + private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) { return new IndexRequest() .index(indexName) - .source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass); + .source(XContentType.JSON, + actualClassField, actualClass, + predictedClassField, predictedClass, + topClassesField, Arrays.asList( + new HashMap() {{ + put("class_name", predictedClass); + put("class_probability", p); + }}, + new HashMap() {{ + put("class_name", "other"); + put("class_probability", 1 - p); + }})); } private static final String actualRegression = "regression_actual"; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 063a92ae15d..4616f9c8fb6 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -59,13 +59,13 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.indexlifecycle.WaitForSnapshotAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; @@ -707,7 +707,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(73, namedXContents.size()); + assertEquals(75, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -756,35 +756,39 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertThat(names, hasItems(OutlierDetection.NAME, Classification.NAME, Regression.NAME)); - assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, hasItems( - registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME), - registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME), - registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME), registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME), + registeredMetricName(Classification.NAME, AucRocMetric.NAME), registeredMetricName(Classification.NAME, AccuracyMetric.NAME), - registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME), - registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME), + registeredMetricName(Classification.NAME, PrecisionMetric.NAME), + registeredMetricName(Classification.NAME, RecallMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, HuberMetric.NAME), registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); - assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, hasItems( - registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME), - registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME), - registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME), + registeredMetricName( + OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME), registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME), + registeredMetricName(Classification.NAME, AucRocMetric.NAME), registeredMetricName(Classification.NAME, AccuracyMetric.NAME), - registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME), - registeredMetricName( - Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME), + registeredMetricName(Classification.NAME, PrecisionMetric.NAME), + registeredMetricName(Classification.NAME, RecallMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index ccfa24c7fa1..3aaca3a86ae 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -156,15 +156,15 @@ import org.elasticsearch.client.ml.dataframe.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; @@ -201,6 +201,7 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition; import org.elasticsearch.client.ml.job.results.Influencer; import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.stats.JobStats; +import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; @@ -3326,7 +3327,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { 30, TimeUnit.SECONDS); } - public void testEvaluateDataFrame() throws Exception { + public void testEvaluateDataFrame_OutlierDetection() throws Exception { String indexName = "evaluate-test-index"; CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName) @@ -3363,10 +3364,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { "label", // <2> "p", // <3> // Evaluation metrics // <4> - PrecisionMetric.at(0.4, 0.5, 0.6), // <5> - RecallMetric.at(0.5, 0.7), // <6> + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), // <5> + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), // <6> ConfusionMatrixMetric.at(0.5), // <7> - AucRocMetric.withCurve()); // <8> + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()); // <8> // end::evaluate-data-frame-evaluation-outlierdetection // tag::evaluate-data-frame-request @@ -3386,7 +3387,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { // end::evaluate-data-frame-response // tag::evaluate-data-frame-results-outlierdetection - PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1> + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult = + response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME); // <1> double precision = precisionResult.getScoreByThreshold("0.4"); // <2> ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3> @@ -3395,7 +3397,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { assertThat( metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()), - containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME)); + containsInAnyOrder( + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME, + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME, + ConfusionMatrixMetric.NAME, + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)); assertThat(precision, closeTo(0.6, 1e-9)); assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9 assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4 @@ -3409,10 +3415,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { new OutlierDetection( "label", "p", - PrecisionMetric.at(0.4, 0.5, 0.6), - RecallMetric.at(0.5, 0.7), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), - AucRocMetric.withCurve())); + org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve())); // tag::evaluate-data-frame-execute-listener ActionListener listener = new ActionListener() { @@ -3452,21 +3458,39 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .startObject("predicted_class") .field("type", "keyword") .endObject() + .startObject("ml.top_classes") + .field("type", "nested") + .endObject() .endObject() .endObject()); + TriFunction indexRequest = (actualClass, predictedClass, p) -> { + return new IndexRequest() + .source(XContentType.JSON, + "actual_class", actualClass, + "predicted_class", predictedClass, + "ml.top_classes", Arrays.asList( + new HashMap() {{ + put("class_name", predictedClass); + put("class_probability", p); + }}, + new HashMap() {{ + put("class_name", "other"); + put("class_probability", 1 - p); + }})); + }; BulkRequest bulkRequest = new BulkRequest(indexName) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8 - .add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9 + .add(indexRequest.apply("cat", "cat", 0.9)) // #0 + .add(indexRequest.apply("cat", "cat", 0.9)) // #1 + .add(indexRequest.apply("cat", "cat", 0.9)) // #2 + .add(indexRequest.apply("cat", "dog", 0.9)) // #3 + .add(indexRequest.apply("cat", "fox", 0.9)) // #4 + .add(indexRequest.apply("dog", "cat", 0.9)) // #5 + .add(indexRequest.apply("dog", "dog", 0.9)) // #6 + .add(indexRequest.apply("dog", "dog", 0.9)) // #7 + .add(indexRequest.apply("dog", "dog", 0.9)) // #8 + .add(indexRequest.apply("ant", "cat", 0.9)); // #9 RestHighLevelClient client = highLevelClient(); client.indices().create(createIndexRequest, RequestOptions.DEFAULT); client.bulk(bulkRequest, RequestOptions.DEFAULT); @@ -3476,11 +3500,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1> "actual_class", // <2> "predicted_class", // <3> - // Evaluation metrics // <4> - new AccuracyMetric(), // <5> - new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6> - new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7> - new MulticlassConfusionMatrixMetric(3)); // <8> + "ml.top_classes", // <4> + // Evaluation metrics // <5> + new AccuracyMetric(), // <6> + new PrecisionMetric(), // <7> + new RecallMetric(), // <8> + new MulticlassConfusionMatrixMetric(3), // <9> + AucRocMetric.forClass("cat")); // <10> // end::evaluate-data-frame-evaluation-classification EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); @@ -3490,12 +3516,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1> double accuracy = accuracyResult.getOverallAccuracy(); // <2> - org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult = - response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3> + PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <3> double precision = precisionResult.getAvgPrecision(); // <4> - org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult = - response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5> + RecallMetric.Result recallResult = response.getMetricByName(RecallMetric.NAME); // <5> double recall = recallResult.getAvgRecall(); // <6> MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = @@ -3503,19 +3527,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { List confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8> long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9> + + AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10> + double aucRocScore = aucRocResult.getScore(); // <11> + Long aucRocDocCount = aucRocResult.getDocCount(); // <12> // end::evaluate-data-frame-results-classification assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); assertThat(accuracy, equalTo(0.6)); - assertThat( - precisionResult.getMetricName(), - equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)); + assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME)); assertThat(precision, equalTo(0.675)); - assertThat( - recallResult.getMetricName(), - equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)); + assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME)); assertThat(recall, equalTo(0.45)); assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); @@ -3539,6 +3563,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)))); assertThat(otherClassesCount, equalTo(0L)); + + assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); + assertThat(aucRocScore, equalTo(0.2625)); + assertThat(aucRocDocCount, equalTo(5L)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java index 7a05b904e71..50fe97b5195 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; -import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetricResultTests; +import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricAucRocPointTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java similarity index 95% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricAucRocPointTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java index d85e8193cc1..d3242906e34 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricAucRocPointTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; +package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java similarity index 95% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricResultTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java index bf4e3f749a5..98855648357 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; +package org.elasticsearch.client.ml.dataframe.evaluation.classification; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -31,6 +31,7 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static AucRocMetric createRandom() { + return new AucRocMetric( + randomAlphaOfLengthBetween(1, 10), + randomBoolean() ? randomBoolean() : null); + } + + @Override + protected AucRocMetric createTestInstance() { + return createRandom(); + } + + @Override + protected AucRocMetric doParseInstance(XContentParser parser) throws IOException { + return AucRocMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java index 4e8ed73fd5e..e8bd7ed80ee 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java @@ -40,11 +40,16 @@ public class ClassificationTests extends AbstractXContentTestCase metrics = randomSubsetOf( Arrays.asList( + AucRocMetricTests.createRandom(), AccuracyMetricTests.createRandom(), PrecisionMetricTests.createRandom(), RecallMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom())); - return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); + return new Classification( + randomAlphaOfLength(10), + randomBoolean() ? randomAlphaOfLength(10) : null, + randomBoolean() ? randomAlphaOfLength(10) : null, + metrics.isEmpty() ? null : metrics); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricTests.java new file mode 100644 index 00000000000..e866bd1f3c9 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class AucRocMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static AucRocMetric createRandom() { + return new AucRocMetric(randomBoolean() ? randomBoolean() : null); + } + + @Override + protected AucRocMetric createTestInstance() { + return createRandom(); + } + + @Override + protected AucRocMetric doParseInstance(XContentParser parser) throws IOException { + return AucRocMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java index 2f4a531551c..dedb3e59892 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java @@ -40,7 +40,7 @@ public class OutlierDetectionTests extends AbstractXContentTestCase metrics = new ArrayList<>(); if (randomBoolean()) { - metrics.add(new AucRocMetric(randomBoolean())); + metrics.add(AucRocMetricTests.createRandom()); } if (randomBoolean()) { metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1, diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc index 5c96fceed0c..3698e919545 100644 --- a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -51,11 +51,13 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification] <1> Constructing a new evaluation <2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to. <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example. -<4> The remaining parameters are the metrics to be calculated based on the two fields described above -<5> Accuracy -<6> Precision -<7> Recall -<8> Multiclass confusion matrix of size 3 +<4> Name of the field in the index. Its value denotes the array of top classes. Must be nested. +<5> The remaining parameters are the metrics to be calculated based on the two fields described above +<6> Accuracy +<7> Precision +<8> Recall +<9> Multiclass confusion matrix of size 3 +<10> {wikipedia}/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated for class "cat" treated as positive and the rest as negative ===== Regression @@ -115,6 +117,9 @@ include-tagged::{doc-tests-file}[{api}-results-classification] <7> Fetching multiclass confusion matrix metric by name <8> Fetching the contents of the confusion matrix <9> Fetching the number of classes that were not included in the matrix +<10> Fetching AucRoc metric by name +<11> Fetching the actual AucRoc score +<12> Fetching the number of documents that were used in order to calculate AucRoc score ===== Regression