[7.x] [ML] Implement AucRoc metric for classification - HLRC (#62304) (#63058)

This commit is contained in:
Przemysław Witek 2020-09-30 14:04:10 +02:00 committed by GitHub
parent 0860746bf2
commit d677a2b8ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 647 additions and 319 deletions

View File

@ -19,13 +19,13 @@
package org.elasticsearch.client.ml.dataframe.evaluation; package org.elasticsearch.client.ml.dataframe.evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
@ -63,34 +63,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluation metrics // Evaluation metrics
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)), new ParseField(
AucRocMetric::fromXContent), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)), new ParseField(
PrecisionMetric::fromXContent), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)), new ParseField(
RecallMetric::fromXContent), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)), new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
ConfusionMatrixMetric::fromXContent), ConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
AucRocMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)), new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
AccuracyMetric::fromXContent), AccuracyMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName( new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)), PrecisionMetric::fromXContent),
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName( new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)), RecallMetric::fromXContent),
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)), new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
@ -114,34 +122,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluation metrics results // Evaluation metrics results
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)), new ParseField(
AucRocMetric.Result::fromXContent), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)), new ParseField(
PrecisionMetric.Result::fromXContent), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)), new ParseField(
RecallMetric.Result::fromXContent), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)), new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
ConfusionMatrixMetric.Result::fromXContent), ConfusionMatrixMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
AucRocMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)), new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
AccuracyMetric.Result::fromXContent), AccuracyMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName( new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)), PrecisionMetric.Result::fromXContent),
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName( new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)), RecallMetric.Result::fromXContent),
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)), new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),

View File

@ -0,0 +1,264 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * Area under the curve (AUC) of the receiver operating characteristic (ROC).
 * The ROC curve is a plot of the TPR (true positive rate) against
 * the FPR (false positive rate) over a varying threshold.
 * <p>
 * The metric is configured with a mandatory {@code class_name} and an optional
 * {@code include_curve} flag; both are written to and read from XContent.
 */
public class AucRocMetric implements EvaluationMetric {

    public static final String NAME = "auc_roc";

    public static final ParseField CLASS_NAME = new ParseField("class_name");
    public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");

    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
        new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1]));

    static {
        PARSER.declareString(constructorArg(), CLASS_NAME);
        PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
    }

    /**
     * Parses a metric definition from the given parser.
     *
     * @param parser the parser positioned at the metric object
     * @return the parsed metric
     */
    public static AucRocMetric fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates a metric for the given class that does not request the curve.
     *
     * @param className the name of the class; must not be {@code null}
     * @return a metric with {@code include_curve} set to {@code false}
     */
    public static AucRocMetric forClass(String className) {
        return new AucRocMetric(className, false);
    }

    /**
     * Creates a metric for the given class that requests the curve.
     *
     * @param className the name of the class; must not be {@code null}
     * @return a metric with {@code include_curve} set to {@code true}
     */
    public static AucRocMetric forClassWithCurve(String className) {
        return new AucRocMetric(className, true);
    }

    private final String className;
    private final Boolean includeCurve;

    /**
     * @param className    the name of the class; must not be {@code null}
     * @param includeCurve whether the curve is requested; {@code null} leaves the field out of the serialized form
     * @throws NullPointerException if {@code className} is {@code null}
     */
    public AucRocMetric(String className, Boolean includeCurve) {
        this.className = Objects.requireNonNull(className);
        this.includeCurve = includeCurve;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(CLASS_NAME.getPreferredName(), className);
        // include_curve is only emitted when it was explicitly set
        if (includeCurve != null) {
            builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        AucRocMetric that = (AucRocMetric) o;
        return Objects.equals(className, that.className)
            && Objects.equals(includeCurve, that.includeCurve);
    }

    @Override
    public int hashCode() {
        return Objects.hash(className, includeCurve);
    }

    /**
     * Result of the {@code auc_roc} metric: a score, a document count and an optional list of curve points.
     */
    public static class Result implements EvaluationMetric.Result {

        /**
         * Parses a result from the given parser.
         *
         * @param parser the parser positioned at the result object
         * @return the parsed result
         */
        public static Result fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        private static final ParseField SCORE = new ParseField("score");
        private static final ParseField DOC_COUNT = new ParseField("doc_count");
        private static final ParseField CURVE = new ParseField("curve");

        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<Result, Void> PARSER =
            new ConstructingObjectParser<>(
                "auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));

        static {
            PARSER.declareDouble(constructorArg(), SCORE);
            PARSER.declareLong(constructorArg(), DOC_COUNT);
            PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
        }

        private final double score;
        private final long docCount;
        private final List<AucRocPoint> curve;

        /**
         * @param score    the value stored under {@code score}
         * @param docCount the value stored under {@code doc_count}
         * @param curve    the curve points, or {@code null} when no curve is available
         */
        public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
            this.score = score;
            this.docCount = docCount;
            this.curve = curve;
        }

        @Override
        public String getMetricName() {
            return NAME;
        }

        public double getScore() {
            return score;
        }

        public long getDocCount() {
            return docCount;
        }

        /**
         * @return an unmodifiable view of the curve points, or {@code null} when no curve was provided
         */
        public List<AucRocPoint> getCurve() {
            return curve == null ? null : Collections.unmodifiableList(curve);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(SCORE.getPreferredName(), score);
            builder.field(DOC_COUNT.getPreferredName(), docCount);
            // a missing or empty curve is not serialized at all
            if (curve != null && curve.isEmpty() == false) {
                builder.field(CURVE.getPreferredName(), curve);
            }
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Result that = (Result) o;
            // Double.compare (rather than ==) keeps equals consistent with hashCode below:
            // Objects.hash boxes the score, so 0.0 and -0.0 hash differently and NaN hashes equal to NaN.
            return Double.compare(score, that.score) == 0
                && docCount == that.docCount
                && Objects.equals(curve, that.curve);
        }

        @Override
        public int hashCode() {
            return Objects.hash(score, docCount, curve);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }

    /**
     * A single point of the curve, holding the {@code tpr}, {@code fpr} and {@code threshold} values.
     */
    public static final class AucRocPoint implements ToXContentObject {

        /**
         * Parses a curve point from the given parser.
         *
         * @param parser the parser positioned at the point object
         * @return the parsed point
         */
        public static AucRocPoint fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        private static final ParseField TPR = new ParseField("tpr");
        private static final ParseField FPR = new ParseField("fpr");
        private static final ParseField THRESHOLD = new ParseField("threshold");

        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
            new ConstructingObjectParser<>(
                "auc_roc_point",
                true,
                args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));

        static {
            PARSER.declareDouble(constructorArg(), TPR);
            PARSER.declareDouble(constructorArg(), FPR);
            PARSER.declareDouble(constructorArg(), THRESHOLD);
        }

        private final double tpr;
        private final double fpr;
        private final double threshold;

        /**
         * @param tpr       the true positive rate
         * @param fpr       the false positive rate
         * @param threshold the threshold this point corresponds to
         */
        public AucRocPoint(double tpr, double fpr, double threshold) {
            this.tpr = tpr;
            this.fpr = fpr;
            this.threshold = threshold;
        }

        public double getTruePositiveRate() {
            return tpr;
        }

        public double getFalsePositiveRate() {
            return fpr;
        }

        public double getThreshold() {
            return threshold;
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder
                .startObject()
                .field(TPR.getPreferredName(), tpr)
                .field(FPR.getPreferredName(), fpr)
                .field(THRESHOLD.getPreferredName(), threshold)
                .endObject();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            AucRocPoint that = (AucRocPoint) o;
            // Double.compare (rather than ==) keeps equals consistent with the boxed hashing in hashCode
            return Double.compare(tpr, that.tpr) == 0
                && Double.compare(fpr, that.fpr) == 0
                && Double.compare(threshold, that.threshold) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(tpr, fpr, threshold);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }
}

View File

@ -45,15 +45,20 @@ public class Classification implements Evaluation {
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field"); private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
private static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
private static final ParseField METRICS = new ParseField("metrics"); private static final ParseField METRICS = new ParseField("metrics");
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>( public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2])); NAME,
true,
a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List<EvaluationMetric>) a[3]));
static { static {
PARSER.declareString(constructorArg(), ACTUAL_FIELD); PARSER.declareString(constructorArg(), ACTUAL_FIELD);
PARSER.declareString(constructorArg(), PREDICTED_FIELD); PARSER.declareString(optionalConstructorArg(), PREDICTED_FIELD);
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_FIELD);
PARSER.declareNamedObjects( PARSER.declareNamedObjects(
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS); optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
} }
@ -64,32 +69,44 @@ public class Classification implements Evaluation {
/** /**
* The field containing the actual value * The field containing the actual value
* The value of this field is assumed to be numeric
*/ */
private final String actualField; private final String actualField;
/** /**
* The field containing the predicted value * The field containing the predicted value
* The value of this field is assumed to be numeric
*/ */
private final String predictedField; private final String predictedField;
/**
* The field containing the array of top classes
*/
private final String topClassesField;
/** /**
* The list of metrics to calculate * The list of metrics to calculate
*/ */
private final List<EvaluationMetric> metrics; private final List<EvaluationMetric> metrics;
public Classification(String actualField, String predictedField) { public Classification(String actualField,
this(actualField, predictedField, (List<EvaluationMetric>)null); String predictedField,
String topClassesField) {
this(actualField, predictedField, topClassesField, (List<EvaluationMetric>)null);
} }
public Classification(String actualField, String predictedField, EvaluationMetric... metrics) { public Classification(String actualField,
this(actualField, predictedField, Arrays.asList(metrics)); String predictedField,
String topClassesField,
EvaluationMetric... metrics) {
this(actualField, predictedField, topClassesField, Arrays.asList(metrics));
} }
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) { public Classification(String actualField,
@Nullable String predictedField,
@Nullable String topClassesField,
@Nullable List<EvaluationMetric> metrics) {
this.actualField = Objects.requireNonNull(actualField); this.actualField = Objects.requireNonNull(actualField);
this.predictedField = Objects.requireNonNull(predictedField); this.predictedField = predictedField;
this.topClassesField = topClassesField;
if (metrics != null) { if (metrics != null) {
metrics.sort(Comparator.comparing(EvaluationMetric::getName)); metrics.sort(Comparator.comparing(EvaluationMetric::getName));
} }
@ -105,8 +122,12 @@ public class Classification implements Evaluation {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ACTUAL_FIELD.getPreferredName(), actualField); builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField); if (predictedField != null) {
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
}
if (topClassesField != null) {
builder.field(TOP_CLASSES_FIELD.getPreferredName(), topClassesField);
}
if (metrics != null) { if (metrics != null) {
builder.startObject(METRICS.getPreferredName()); builder.startObject(METRICS.getPreferredName());
for (EvaluationMetric metric : metrics) { for (EvaluationMetric metric : metrics) {
@ -126,11 +147,12 @@ public class Classification implements Evaluation {
Classification that = (Classification) o; Classification that = (Classification) o;
return Objects.equals(that.actualField, this.actualField) return Objects.equals(that.actualField, this.actualField)
&& Objects.equals(that.predictedField, this.predictedField) && Objects.equals(that.predictedField, this.predictedField)
&& Objects.equals(that.topClassesField, this.topClassesField)
&& Objects.equals(that.metrics, this.metrics); && Objects.equals(that.metrics, this.metrics);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(actualField, predictedField, metrics); return Objects.hash(actualField, predictedField, topClassesField, metrics);
} }
} }

View File

@ -19,21 +19,14 @@
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects; import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/** /**
@ -49,7 +42,7 @@ public class AucRocMetric implements EvaluationMetric {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER = public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0])); new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((Boolean) args[0]));
static { static {
PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE); PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
@ -63,18 +56,20 @@ public class AucRocMetric implements EvaluationMetric {
return new AucRocMetric(true); return new AucRocMetric(true);
} }
private final boolean includeCurve; private final Boolean includeCurve;
public AucRocMetric(Boolean includeCurve) { public AucRocMetric(Boolean includeCurve) {
this.includeCurve = includeCurve == null ? false : includeCurve; this.includeCurve = includeCurve;
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder builder.startObject();
.startObject() if (includeCurve != null) {
.field(INCLUDE_CURVE.getPreferredName(), includeCurve) builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
.endObject(); }
builder.endObject();
return builder;
} }
@Override @Override
@ -94,148 +89,4 @@ public class AucRocMetric implements EvaluationMetric {
public int hashCode() { public int hashCode() {
return Objects.hash(includeCurve); return Objects.hash(includeCurve);
} }
public static class Result implements EvaluationMetric.Result {
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final ParseField SCORE = new ParseField("score");
private static final ParseField CURVE = new ParseField("curve");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
static {
PARSER.declareDouble(constructorArg(), SCORE);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
}
private final double score;
private final List<AucRocPoint> curve;
public Result(double score, @Nullable List<AucRocPoint> curve) {
this.score = score;
this.curve = curve;
}
@Override
public String getMetricName() {
return NAME;
}
public double getScore() {
return score;
}
public List<AucRocPoint> getCurve() {
return curve == null ? null : Collections.unmodifiableList(curve);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(SCORE.getPreferredName(), score);
if (curve != null && curve.isEmpty() == false) {
builder.field(CURVE.getPreferredName(), curve);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(score, that.score)
&& Objects.equals(curve, that.curve);
}
@Override
public int hashCode() {
return Objects.hash(score, curve);
}
@Override
public String toString() {
return Strings.toString(this);
}
}
public static final class AucRocPoint implements ToXContentObject {
public static AucRocPoint fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final ParseField TPR = new ParseField("tpr");
private static final ParseField FPR = new ParseField("fpr");
private static final ParseField THRESHOLD = new ParseField("threshold");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
new ConstructingObjectParser<>(
"auc_roc_point",
true,
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
static {
PARSER.declareDouble(constructorArg(), TPR);
PARSER.declareDouble(constructorArg(), FPR);
PARSER.declareDouble(constructorArg(), THRESHOLD);
}
private final double tpr;
private final double fpr;
private final double threshold;
public AucRocPoint(double tpr, double fpr, double threshold) {
this.tpr = tpr;
this.fpr = fpr;
this.threshold = threshold;
}
public double getTruePositiveRate() {
return tpr;
}
public double getFalsePositiveRate() {
return fpr;
}
public double getThreshold() {
return threshold;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder
.startObject()
.field(TPR.getPreferredName(), tpr)
.field(FPR.getPreferredName(), fpr)
.field(THRESHOLD.getPreferredName(), threshold)
.endObject();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AucRocPoint that = (AucRocPoint) o;
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
}
@Override
public int hashCode() {
return Objects.hash(tpr, fpr, threshold);
}
@Override
public String toString() {
return Strings.toString(this);
}
}
} }

View File

@ -138,13 +138,13 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
@ -1774,15 +1774,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
new OutlierDetection( new OutlierDetection(
actualField, actualField,
probabilityField, probabilityField,
PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve())); org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
ConfusionMatrixMetric.at(0.5),
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME); org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME)); evaluateDataFrameResponse.getMetricByName(
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME);
assertThat(
precisionResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME));
// Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9) // Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9)
assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9)); assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9));
// Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9)
@ -1791,8 +1798,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9)); assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9));
assertNull(precisionResult.getScoreByThreshold("0.1")); assertNull(precisionResult.getScoreByThreshold("0.1"));
RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME); org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result recallResult =
assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME)); evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME);
assertThat(
recallResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME));
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9)); assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9));
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
@ -1808,7 +1818,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7 assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7
assertNull(confusionMatrixResult.getScoreByThreshold("0.1")); assertNull(confusionMatrixResult.getScoreByThreshold("0.1"));
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); AucRocMetric.Result aucRocResult =
evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9)); assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9));
assertNotNull(aucRocResult.getCurve()); assertNotNull(aucRocResult.getCurve());
@ -1920,24 +1931,40 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
createIndex(indexName, mappingForClassification()); createIndex(indexName, mappingForClassification());
BulkRequest regressionBulk = new BulkRequest() BulkRequest regressionBulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForClassification(indexName, "cat", "cat")) .add(docForClassification(indexName, "cat", "cat", 0.9))
.add(docForClassification(indexName, "cat", "cat")) .add(docForClassification(indexName, "cat", "cat", 0.85))
.add(docForClassification(indexName, "cat", "cat")) .add(docForClassification(indexName, "cat", "cat", 0.95))
.add(docForClassification(indexName, "cat", "dog")) .add(docForClassification(indexName, "cat", "dog", 0.4))
.add(docForClassification(indexName, "cat", "fish")) .add(docForClassification(indexName, "cat", "fish", 0.35))
.add(docForClassification(indexName, "dog", "cat")) .add(docForClassification(indexName, "dog", "cat", 0.5))
.add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog", 0.4))
.add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog", 0.35))
.add(docForClassification(indexName, "dog", "dog")) .add(docForClassification(indexName, "dog", "dog", 0.6))
.add(docForClassification(indexName, "ant", "cat")); .add(docForClassification(indexName, "ant", "cat", 0.1));
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
{ // AucRoc
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
indexName, null, new Classification(actualClassField, null, topClassesField, AucRocMetric.forClassWithCurve("cat")));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
assertThat(aucRocResult.getDocCount(), equalTo(5L));
assertNotNull(aucRocResult.getCurve());
}
{ // Accuracy { // Accuracy
EvaluateDataFrameRequest evaluateDataFrameRequest = EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest( new EvaluateDataFrameRequest(
indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric())); indexName, null, new Classification(actualClassField, predictedClassField, null, new AccuracyMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@ -1961,65 +1988,47 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
{ // Precision { // Precision
EvaluateDataFrameRequest evaluateDataFrameRequest = EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest( new EvaluateDataFrameRequest(
indexName, indexName, null, new Classification(actualClassField, predictedClassField, null, new PrecisionMetric()));
null,
new Classification(
actualClassField,
predictedClassField,
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult = PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
evaluateDataFrameResponse.getMetricByName( assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
assertThat(
precisionResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
assertThat( assertThat(
precisionResult.getClasses(), precisionResult.getClasses(),
equalTo( equalTo(
Arrays.asList( Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly // 3 out of 5 examples labeled as "cat" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6), new PrecisionMetric.PerClassResult("cat", 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly // 3 out of 4 examples labeled as "dog" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75)))); new PrecisionMetric.PerClassResult("dog", 0.75))));
assertThat(precisionResult.getAvgPrecision(), equalTo(0.675)); assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
} }
{ // Recall { // Recall
EvaluateDataFrameRequest evaluateDataFrameRequest = EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest( new EvaluateDataFrameRequest(
indexName, indexName, null, new Classification(actualClassField, predictedClassField, null, new RecallMetric()));
null,
new Classification(
actualClassField,
predictedClassField,
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME)); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult = RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
evaluateDataFrameResponse.getMetricByName( assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
assertThat(
recallResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
assertThat( assertThat(
recallResult.getClasses(), recallResult.getClasses(),
equalTo( equalTo(
Arrays.asList( Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly // 3 out of 5 examples labeled as "cat" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6), new RecallMetric.PerClassResult("cat", 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly // 3 out of 4 examples labeled as "dog" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75), new RecallMetric.PerClassResult("dog", 0.75),
// no examples labeled as "ant" were classified correctly // no examples labeled as "ant" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0)))); new RecallMetric.PerClassResult("ant", 0.0))));
assertThat(recallResult.getAvgRecall(), equalTo(0.45)); assertThat(recallResult.getAvgRecall(), equalTo(0.45));
} }
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead { // No size provided for MulticlassConfusionMatrixMetric, default used instead
@ -2027,7 +2036,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
new EvaluateDataFrameRequest( new EvaluateDataFrameRequest(
indexName, indexName,
null, null,
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric())); new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@ -2072,7 +2081,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
new EvaluateDataFrameRequest( new EvaluateDataFrameRequest(
indexName, indexName,
null, null,
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2))); new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric(2)));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@ -2146,6 +2155,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
private static final String actualClassField = "actual_class"; private static final String actualClassField = "actual_class";
private static final String predictedClassField = "predicted_class"; private static final String predictedClassField = "predicted_class";
private static final String topClassesField = "top_classes";
private static XContentBuilder mappingForClassification() throws IOException { private static XContentBuilder mappingForClassification() throws IOException {
return XContentFactory.jsonBuilder().startObject() return XContentFactory.jsonBuilder().startObject()
@ -2156,14 +2166,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.startObject(predictedClassField) .startObject(predictedClassField)
.field("type", "keyword") .field("type", "keyword")
.endObject() .endObject()
.startObject(topClassesField)
.field("type", "nested")
.endObject()
.endObject() .endObject()
.endObject(); .endObject();
} }
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) { private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
return new IndexRequest() return new IndexRequest()
.index(indexName) .index(indexName)
.source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass); .source(XContentType.JSON,
actualClassField, actualClass,
predictedClassField, predictedClass,
topClassesField, Arrays.asList(
new HashMap<String, Object>() {{
put("class_name", predictedClass);
put("class_probability", p);
}},
new HashMap<String, Object>() {{
put("class_name", "other");
put("class_probability", 1 - p);
}}));
} }
private static final String actualRegression = "regression_actual"; private static final String actualRegression = "regression_actual";

View File

@ -59,13 +59,13 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction;
import org.elasticsearch.client.indexlifecycle.WaitForSnapshotAction; import org.elasticsearch.client.indexlifecycle.WaitForSnapshotAction;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
@ -707,7 +707,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() { public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents(); List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(73, namedXContents.size()); assertEquals(75, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>(); Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>(); List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) { for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -756,35 +756,39 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(TimeSyncConfig.NAME)); assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(OutlierDetection.NAME, Classification.NAME, Regression.NAME)); assertThat(names, hasItems(OutlierDetection.NAME, Classification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertThat(names, assertThat(names,
hasItems( hasItems(
registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME), registeredMetricName(
registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME), OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME),
registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME),
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME),
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME), registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
registeredMetricName(Classification.NAME, AucRocMetric.NAME),
registeredMetricName(Classification.NAME, AccuracyMetric.NAME), registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
registeredMetricName( registeredMetricName(Classification.NAME, PrecisionMetric.NAME),
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME), registeredMetricName(Classification.NAME, RecallMetric.NAME),
registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, HuberMetric.NAME), registeredMetricName(Regression.NAME, HuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names, assertThat(names,
hasItems( hasItems(
registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME), registeredMetricName(
registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME), OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME),
registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME), registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME),
registeredMetricName(
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME),
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME), registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
registeredMetricName(Classification.NAME, AucRocMetric.NAME),
registeredMetricName(Classification.NAME, AccuracyMetric.NAME), registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
registeredMetricName( registeredMetricName(Classification.NAME, PrecisionMetric.NAME),
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME), registeredMetricName(Classification.NAME, RecallMetric.NAME),
registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),

View File

@ -156,15 +156,15 @@ import org.elasticsearch.client.ml.dataframe.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
@ -201,6 +201,7 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition;
import org.elasticsearch.client.ml.job.results.Influencer; import org.elasticsearch.client.ml.job.results.Influencer;
import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.results.OverallBucket;
import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.client.ml.job.stats.JobStats;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
@ -3326,7 +3327,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
30, TimeUnit.SECONDS); 30, TimeUnit.SECONDS);
} }
public void testEvaluateDataFrame() throws Exception { public void testEvaluateDataFrame_OutlierDetection() throws Exception {
String indexName = "evaluate-test-index"; String indexName = "evaluate-test-index";
CreateIndexRequest createIndexRequest = CreateIndexRequest createIndexRequest =
new CreateIndexRequest(indexName) new CreateIndexRequest(indexName)
@ -3363,10 +3364,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
"label", // <2> "label", // <2>
"p", // <3> "p", // <3>
// Evaluation metrics // <4> // Evaluation metrics // <4>
PrecisionMetric.at(0.4, 0.5, 0.6), // <5> org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
RecallMetric.at(0.5, 0.7), // <6> org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), // <6>
ConfusionMatrixMetric.at(0.5), // <7> ConfusionMatrixMetric.at(0.5), // <7>
AucRocMetric.withCurve()); // <8> org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()); // <8>
// end::evaluate-data-frame-evaluation-outlierdetection // end::evaluate-data-frame-evaluation-outlierdetection
// tag::evaluate-data-frame-request // tag::evaluate-data-frame-request
@ -3386,7 +3387,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// end::evaluate-data-frame-response // end::evaluate-data-frame-response
// tag::evaluate-data-frame-results-outlierdetection // tag::evaluate-data-frame-results-outlierdetection
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1> org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME); // <1>
double precision = precisionResult.getScoreByThreshold("0.4"); // <2> double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3> ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
@ -3395,7 +3397,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
assertThat( assertThat(
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()), metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME)); containsInAnyOrder(
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME,
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME,
ConfusionMatrixMetric.NAME,
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME));
assertThat(precision, closeTo(0.6, 1e-9)); assertThat(precision, closeTo(0.6, 1e-9));
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9 assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4 assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4
@ -3409,10 +3415,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
new OutlierDetection( new OutlierDetection(
"label", "label",
"p", "p",
PrecisionMetric.at(0.4, 0.5, 0.6), org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
RecallMetric.at(0.5, 0.7), org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
ConfusionMatrixMetric.at(0.5), ConfusionMatrixMetric.at(0.5),
AucRocMetric.withCurve())); org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
// tag::evaluate-data-frame-execute-listener // tag::evaluate-data-frame-execute-listener
ActionListener<EvaluateDataFrameResponse> listener = new ActionListener<EvaluateDataFrameResponse>() { ActionListener<EvaluateDataFrameResponse> listener = new ActionListener<EvaluateDataFrameResponse>() {
@ -3452,21 +3458,39 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.startObject("predicted_class") .startObject("predicted_class")
.field("type", "keyword") .field("type", "keyword")
.endObject() .endObject()
.startObject("ml.top_classes")
.field("type", "nested")
.endObject()
.endObject() .endObject()
.endObject()); .endObject());
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
return new IndexRequest()
.source(XContentType.JSON,
"actual_class", actualClass,
"predicted_class", predictedClass,
"ml.top_classes", Arrays.asList(
new HashMap<String, Object>() {{
put("class_name", predictedClass);
put("class_probability", p);
}},
new HashMap<String, Object>() {{
put("class_name", "other");
put("class_probability", 1 - p);
}}));
};
BulkRequest bulkRequest = BulkRequest bulkRequest =
new BulkRequest(indexName) new BulkRequest(indexName)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0 .add(indexRequest.apply("cat", "cat", 0.9)) // #0
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1 .add(indexRequest.apply("cat", "cat", 0.9)) // #1
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2 .add(indexRequest.apply("cat", "cat", 0.9)) // #2
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3 .add(indexRequest.apply("cat", "dog", 0.9)) // #3
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4 .add(indexRequest.apply("cat", "fox", 0.9)) // #4
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5 .add(indexRequest.apply("dog", "cat", 0.9)) // #5
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6 .add(indexRequest.apply("dog", "dog", 0.9)) // #6
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7 .add(indexRequest.apply("dog", "dog", 0.9)) // #7
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8 .add(indexRequest.apply("dog", "dog", 0.9)) // #8
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9 .add(indexRequest.apply("ant", "cat", 0.9)); // #9
RestHighLevelClient client = highLevelClient(); RestHighLevelClient client = highLevelClient();
client.indices().create(createIndexRequest, RequestOptions.DEFAULT); client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
client.bulk(bulkRequest, RequestOptions.DEFAULT); client.bulk(bulkRequest, RequestOptions.DEFAULT);
@ -3476,11 +3500,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1> new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
"actual_class", // <2> "actual_class", // <2>
"predicted_class", // <3> "predicted_class", // <3>
// Evaluation metrics // <4> "ml.top_classes", // <4>
new AccuracyMetric(), // <5> // Evaluation metrics // <5>
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6> new AccuracyMetric(), // <6>
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7> new PrecisionMetric(), // <7>
new MulticlassConfusionMatrixMetric(3)); // <8> new RecallMetric(), // <8>
new MulticlassConfusionMatrixMetric(3), // <9>
AucRocMetric.forClass("cat")); // <10>
// end::evaluate-data-frame-evaluation-classification // end::evaluate-data-frame-evaluation-classification
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation); EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@ -3490,12 +3516,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1> AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
double accuracy = accuracyResult.getOverallAccuracy(); // <2> double accuracy = accuracyResult.getOverallAccuracy(); // <2>
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult = PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <3>
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3>
double precision = precisionResult.getAvgPrecision(); // <4> double precision = precisionResult.getAvgPrecision(); // <4>
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult = RecallMetric.Result recallResult = response.getMetricByName(RecallMetric.NAME); // <5>
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5>
double recall = recallResult.getAvgRecall(); // <6> double recall = recallResult.getAvgRecall(); // <6>
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix = MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
@ -3503,19 +3527,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8> List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9> long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
double aucRocScore = aucRocResult.getScore(); // <11>
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
// end::evaluate-data-frame-results-classification // end::evaluate-data-frame-results-classification
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME)); assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat(accuracy, equalTo(0.6)); assertThat(accuracy, equalTo(0.6));
assertThat( assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
precisionResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
assertThat(precision, equalTo(0.675)); assertThat(precision, equalTo(0.675));
assertThat( assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
recallResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
assertThat(recall, equalTo(0.45)); assertThat(recall, equalTo(0.45));
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME)); assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
@ -3539,6 +3563,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L)))); 0L))));
assertThat(otherClassesCount, equalTo(0L)); assertThat(otherClassesCount, equalTo(0L));
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
assertThat(aucRocScore, equalTo(0.2625));
assertThat(aucRocDocCount, equalTo(5L));
} }
} }

View File

@ -26,7 +26,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests;
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests; import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests;

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;

View File

@ -16,7 +16,7 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection; package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
@ -31,6 +31,7 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
public static AucRocMetric.Result randomResult() { public static AucRocMetric.Result randomResult() {
return new AucRocMetric.Result( return new AucRocMetric.Result(
randomDouble(), randomDouble(),
randomLong(),
Stream Stream
.generate(AucRocMetricAucRocPointTests::randomPoint) .generate(AucRocMetricAucRocPointTests::randomPoint)
.limit(randomIntBetween(1, 10)) .limit(randomIntBetween(1, 10))

View File

@ -0,0 +1,55 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip test for the classification flavour of {@link AucRocMetric}.
 *
 * <p>Instances are produced by {@link #createRandom()} and parsed back through
 * {@link AucRocMetric#fromXContent(XContentParser)}.
 */
public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {

    /**
     * Creates a randomized metric.
     *
     * <p>The first constructor argument is a random string of 1 to 10 characters (presumably the
     * class name -- confirm against the {@code AucRocMetric} constructor). The second argument is
     * randomly left {@code null} or set to a random boolean.
     *
     * @return a freshly built random metric
     */
    public static AucRocMetric createRandom() {
        // Draw order matters for seed reproducibility: string first, then the flag.
        final String randomName = randomAlphaOfLengthBetween(1, 10);
        Boolean optionalFlag = null;
        if (randomBoolean()) {
            optionalFlag = randomBoolean();
        }
        return new AucRocMetric(randomName, optionalFlag);
    }

    @Override
    protected AucRocMetric createTestInstance() {
        return createRandom();
    }

    @Override
    protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
        return AucRocMetric.fromXContent(parser);
    }

    /** Builds the registry from the parsers exposed by the ML evaluation provider. */
    @Override
    protected NamedXContentRegistry xContentRegistry() {
        final MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }
}

View File

@ -40,11 +40,16 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
List<EvaluationMetric> metrics = List<EvaluationMetric> metrics =
randomSubsetOf( randomSubsetOf(
Arrays.asList( Arrays.asList(
AucRocMetricTests.createRandom(),
AccuracyMetricTests.createRandom(), AccuracyMetricTests.createRandom(),
PrecisionMetricTests.createRandom(), PrecisionMetricTests.createRandom(),
RecallMetricTests.createRandom(), RecallMetricTests.createRandom(),
MulticlassConfusionMatrixMetricTests.createRandom())); MulticlassConfusionMatrixMetricTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics); return new Classification(
randomAlphaOfLength(10),
randomBoolean() ? randomAlphaOfLength(10) : null,
randomBoolean() ? randomAlphaOfLength(10) : null,
metrics.isEmpty() ? null : metrics);
} }
@Override @Override

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip test for the outlier-detection flavour of {@link AucRocMetric}.
 *
 * <p>Instances are produced by {@link #createRandom()} and parsed back through
 * {@link AucRocMetric#fromXContent(XContentParser)}.
 */
public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {

    /**
     * Creates a randomized metric whose single constructor argument is randomly left
     * {@code null} or set to a random boolean.
     *
     * @return a freshly built random metric
     */
    public static AucRocMetric createRandom() {
        Boolean optionalFlag = null;
        if (randomBoolean()) {
            optionalFlag = randomBoolean();
        }
        return new AucRocMetric(optionalFlag);
    }

    @Override
    protected AucRocMetric createTestInstance() {
        return createRandom();
    }

    @Override
    protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
        return AucRocMetric.fromXContent(parser);
    }

    /** Builds the registry from the parsers exposed by the ML evaluation provider. */
    @Override
    protected NamedXContentRegistry xContentRegistry() {
        final MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }
}

View File

@ -40,7 +40,7 @@ public class OutlierDetectionTests extends AbstractXContentTestCase<OutlierDetec
public static OutlierDetection createRandom() { public static OutlierDetection createRandom() {
List<EvaluationMetric> metrics = new ArrayList<>(); List<EvaluationMetric> metrics = new ArrayList<>();
if (randomBoolean()) { if (randomBoolean()) {
metrics.add(new AucRocMetric(randomBoolean())); metrics.add(AucRocMetricTests.createRandom());
} }
if (randomBoolean()) { if (randomBoolean()) {
metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1, metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1,

View File

@ -51,11 +51,13 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
<1> Constructing a new evaluation <1> Constructing a new evaluation
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to. <2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example. <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
<4> The remaining parameters are the metrics to be calculated based on the two fields described above <4> Name of the field in the index. Its value denotes the array of top classes. Must be nested.
<5> Accuracy <5> The remaining parameters are the metrics to be calculated based on the fields described above
<6> Precision <6> Accuracy
<7> Recall <7> Precision
<8> Multiclass confusion matrix of size 3 <8> Recall
<9> Multiclass confusion matrix of size 3
<10> {wikipedia}/Receiver_operating_characteristic#Area_under_the_curve[AUC ROC] calculated for class "cat" treated as positive and the rest as negative
===== Regression ===== Regression
@ -115,6 +117,9 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
<7> Fetching multiclass confusion matrix metric by name <7> Fetching multiclass confusion matrix metric by name
<8> Fetching the contents of the confusion matrix <8> Fetching the contents of the confusion matrix
<9> Fetching the number of classes that were not included in the matrix <9> Fetching the number of classes that were not included in the matrix
<10> Fetching AucRoc metric by name
<11> Fetching the actual AucRoc score
<12> Fetching the number of documents that were used to calculate the AucRoc score
===== Regression ===== Regression