This commit is contained in:
parent
0860746bf2
commit
d677a2b8ee
|
@ -19,13 +19,13 @@
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation;
|
package org.elasticsearch.client.ml.dataframe.evaluation;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
|
@ -63,34 +63,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||||
// Evaluation metrics
|
// Evaluation metrics
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)),
|
new ParseField(
|
||||||
AucRocMetric::fromXContent),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)),
|
new ParseField(
|
||||||
PrecisionMetric::fromXContent),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)),
|
new ParseField(
|
||||||
RecallMetric::fromXContent),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
|
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
|
||||||
ConfusionMatrixMetric::fromXContent),
|
ConfusionMatrixMetric::fromXContent),
|
||||||
|
new NamedXContentRegistry.Entry(
|
||||||
|
EvaluationMetric.class,
|
||||||
|
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
|
||||||
|
AucRocMetric::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
||||||
AccuracyMetric::fromXContent),
|
AccuracyMetric::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(
|
new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
|
PrecisionMetric::fromXContent),
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
|
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(
|
new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
|
RecallMetric::fromXContent),
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
|
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
||||||
|
@ -114,34 +122,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||||
// Evaluation metrics results
|
// Evaluation metrics results
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)),
|
new ParseField(
|
||||||
AucRocMetric.Result::fromXContent),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)),
|
new ParseField(
|
||||||
PrecisionMetric.Result::fromXContent),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)),
|
new ParseField(
|
||||||
RecallMetric.Result::fromXContent),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
|
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
|
||||||
ConfusionMatrixMetric.Result::fromXContent),
|
ConfusionMatrixMetric.Result::fromXContent),
|
||||||
|
new NamedXContentRegistry.Entry(
|
||||||
|
EvaluationMetric.Result.class,
|
||||||
|
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
|
||||||
|
AucRocMetric.Result::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
||||||
AccuracyMetric.Result::fromXContent),
|
AccuracyMetric.Result::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(
|
new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
|
PrecisionMetric.Result::fromXContent),
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
|
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(
|
new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
|
RecallMetric.Result::fromXContent),
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
|
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
||||||
|
|
|
@ -0,0 +1,264 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
|
import org.elasticsearch.common.Nullable;
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
|
||||||
|
* The ROC curve is a plot of the TPR (true positive rate) against
|
||||||
|
* the FPR (false positive rate) over a varying threshold.
|
||||||
|
*/
|
||||||
|
public class AucRocMetric implements EvaluationMetric {
|
||||||
|
|
||||||
|
public static final String NAME = "auc_roc";
|
||||||
|
|
||||||
|
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||||
|
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||||
|
PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AucRocMetric fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AucRocMetric forClass(String className) {
|
||||||
|
return new AucRocMetric(className, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AucRocMetric forClassWithCurve(String className) {
|
||||||
|
return new AucRocMetric(className, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final String className;
|
||||||
|
private final Boolean includeCurve;
|
||||||
|
|
||||||
|
public AucRocMetric(String className, Boolean includeCurve) {
|
||||||
|
this.className = Objects.requireNonNull(className);
|
||||||
|
this.includeCurve = includeCurve;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||||
|
if (includeCurve != null) {
|
||||||
|
builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
AucRocMetric that = (AucRocMetric) o;
|
||||||
|
return Objects.equals(className, that.className)
|
||||||
|
&& Objects.equals(includeCurve, that.includeCurve);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(className, includeCurve);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Result implements EvaluationMetric.Result {
|
||||||
|
|
||||||
|
public static Result fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final ParseField SCORE = new ParseField("score");
|
||||||
|
private static final ParseField DOC_COUNT = new ParseField("doc_count");
|
||||||
|
private static final ParseField CURVE = new ParseField("curve");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>(
|
||||||
|
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareDouble(constructorArg(), SCORE);
|
||||||
|
PARSER.declareLong(constructorArg(), DOC_COUNT);
|
||||||
|
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final double score;
|
||||||
|
private final long docCount;
|
||||||
|
private final List<AucRocPoint> curve;
|
||||||
|
|
||||||
|
public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
|
||||||
|
this.score = score;
|
||||||
|
this.docCount = docCount;
|
||||||
|
this.curve = curve;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMetricName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getScore() {
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long getDocCount() {
|
||||||
|
return docCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<AucRocPoint> getCurve() {
|
||||||
|
return curve == null ? null : Collections.unmodifiableList(curve);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(SCORE.getPreferredName(), score);
|
||||||
|
builder.field(DOC_COUNT.getPreferredName(), docCount);
|
||||||
|
if (curve != null && curve.isEmpty() == false) {
|
||||||
|
builder.field(CURVE.getPreferredName(), curve);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
Result that = (Result) o;
|
||||||
|
return score == that.score
|
||||||
|
&& docCount == that.docCount
|
||||||
|
&& Objects.equals(curve, that.curve);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(score, docCount, curve);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Strings.toString(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final class AucRocPoint implements ToXContentObject {
|
||||||
|
|
||||||
|
public static AucRocPoint fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final ParseField TPR = new ParseField("tpr");
|
||||||
|
private static final ParseField FPR = new ParseField("fpr");
|
||||||
|
private static final ParseField THRESHOLD = new ParseField("threshold");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>(
|
||||||
|
"auc_roc_point",
|
||||||
|
true,
|
||||||
|
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareDouble(constructorArg(), TPR);
|
||||||
|
PARSER.declareDouble(constructorArg(), FPR);
|
||||||
|
PARSER.declareDouble(constructorArg(), THRESHOLD);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final double tpr;
|
||||||
|
private final double fpr;
|
||||||
|
private final double threshold;
|
||||||
|
|
||||||
|
public AucRocPoint(double tpr, double fpr, double threshold) {
|
||||||
|
this.tpr = tpr;
|
||||||
|
this.fpr = fpr;
|
||||||
|
this.threshold = threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getTruePositiveRate() {
|
||||||
|
return tpr;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getFalsePositiveRate() {
|
||||||
|
return fpr;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getThreshold() {
|
||||||
|
return threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
return builder
|
||||||
|
.startObject()
|
||||||
|
.field(TPR.getPreferredName(), tpr)
|
||||||
|
.field(FPR.getPreferredName(), fpr)
|
||||||
|
.field(THRESHOLD.getPreferredName(), threshold)
|
||||||
|
.endObject();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
AucRocPoint that = (AucRocPoint) o;
|
||||||
|
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(tpr, fpr, threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return Strings.toString(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -45,15 +45,20 @@ public class Classification implements Evaluation {
|
||||||
|
|
||||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||||
|
private static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
|
||||||
|
|
||||||
private static final ParseField METRICS = new ParseField("metrics");
|
private static final ParseField METRICS = new ParseField("metrics");
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
|
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
|
||||||
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
NAME,
|
||||||
|
true,
|
||||||
|
a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List<EvaluationMetric>) a[3]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
|
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
|
||||||
PARSER.declareString(constructorArg(), PREDICTED_FIELD);
|
PARSER.declareString(optionalConstructorArg(), PREDICTED_FIELD);
|
||||||
|
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_FIELD);
|
||||||
PARSER.declareNamedObjects(
|
PARSER.declareNamedObjects(
|
||||||
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
|
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
|
||||||
}
|
}
|
||||||
|
@ -64,32 +69,44 @@ public class Classification implements Evaluation {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The field containing the actual value
|
* The field containing the actual value
|
||||||
* The value of this field is assumed to be numeric
|
|
||||||
*/
|
*/
|
||||||
private final String actualField;
|
private final String actualField;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The field containing the predicted value
|
* The field containing the predicted value
|
||||||
* The value of this field is assumed to be numeric
|
|
||||||
*/
|
*/
|
||||||
private final String predictedField;
|
private final String predictedField;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The field containing the array of top classes
|
||||||
|
*/
|
||||||
|
private final String topClassesField;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The list of metrics to calculate
|
* The list of metrics to calculate
|
||||||
*/
|
*/
|
||||||
private final List<EvaluationMetric> metrics;
|
private final List<EvaluationMetric> metrics;
|
||||||
|
|
||||||
public Classification(String actualField, String predictedField) {
|
public Classification(String actualField,
|
||||||
this(actualField, predictedField, (List<EvaluationMetric>)null);
|
String predictedField,
|
||||||
|
String topClassesField) {
|
||||||
|
this(actualField, predictedField, topClassesField, (List<EvaluationMetric>)null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
|
public Classification(String actualField,
|
||||||
this(actualField, predictedField, Arrays.asList(metrics));
|
String predictedField,
|
||||||
|
String topClassesField,
|
||||||
|
EvaluationMetric... metrics) {
|
||||||
|
this(actualField, predictedField, topClassesField, Arrays.asList(metrics));
|
||||||
}
|
}
|
||||||
|
|
||||||
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
public Classification(String actualField,
|
||||||
|
@Nullable String predictedField,
|
||||||
|
@Nullable String topClassesField,
|
||||||
|
@Nullable List<EvaluationMetric> metrics) {
|
||||||
this.actualField = Objects.requireNonNull(actualField);
|
this.actualField = Objects.requireNonNull(actualField);
|
||||||
this.predictedField = Objects.requireNonNull(predictedField);
|
this.predictedField = predictedField;
|
||||||
|
this.topClassesField = topClassesField;
|
||||||
if (metrics != null) {
|
if (metrics != null) {
|
||||||
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
|
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
|
||||||
}
|
}
|
||||||
|
@ -105,8 +122,12 @@ public class Classification implements Evaluation {
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
if (predictedField != null) {
|
||||||
|
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||||
|
}
|
||||||
|
if (topClassesField != null) {
|
||||||
|
builder.field(TOP_CLASSES_FIELD.getPreferredName(), topClassesField);
|
||||||
|
}
|
||||||
if (metrics != null) {
|
if (metrics != null) {
|
||||||
builder.startObject(METRICS.getPreferredName());
|
builder.startObject(METRICS.getPreferredName());
|
||||||
for (EvaluationMetric metric : metrics) {
|
for (EvaluationMetric metric : metrics) {
|
||||||
|
@ -126,11 +147,12 @@ public class Classification implements Evaluation {
|
||||||
Classification that = (Classification) o;
|
Classification that = (Classification) o;
|
||||||
return Objects.equals(that.actualField, this.actualField)
|
return Objects.equals(that.actualField, this.actualField)
|
||||||
&& Objects.equals(that.predictedField, this.predictedField)
|
&& Objects.equals(that.predictedField, this.predictedField)
|
||||||
|
&& Objects.equals(that.topClassesField, this.topClassesField)
|
||||||
&& Objects.equals(that.metrics, this.metrics);
|
&& Objects.equals(that.metrics, this.metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(actualField, predictedField, metrics);
|
return Objects.hash(actualField, predictedField, topClassesField, metrics);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,21 +19,14 @@
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
||||||
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.common.Nullable;
|
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.Strings;
|
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ToXContent;
|
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -49,7 +42,7 @@ public class AucRocMetric implements EvaluationMetric {
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
|
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
|
||||||
new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0]));
|
new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((Boolean) args[0]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
|
PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
|
||||||
|
@ -63,18 +56,20 @@ public class AucRocMetric implements EvaluationMetric {
|
||||||
return new AucRocMetric(true);
|
return new AucRocMetric(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final boolean includeCurve;
|
private final Boolean includeCurve;
|
||||||
|
|
||||||
public AucRocMetric(Boolean includeCurve) {
|
public AucRocMetric(Boolean includeCurve) {
|
||||||
this.includeCurve = includeCurve == null ? false : includeCurve;
|
this.includeCurve = includeCurve;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
return builder
|
builder.startObject();
|
||||||
.startObject()
|
if (includeCurve != null) {
|
||||||
.field(INCLUDE_CURVE.getPreferredName(), includeCurve)
|
builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
|
||||||
.endObject();
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -94,148 +89,4 @@ public class AucRocMetric implements EvaluationMetric {
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(includeCurve);
|
return Objects.hash(includeCurve);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Result implements EvaluationMetric.Result {
|
|
||||||
|
|
||||||
public static Result fromXContent(XContentParser parser) {
|
|
||||||
return PARSER.apply(parser, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final ParseField SCORE = new ParseField("score");
|
|
||||||
private static final ParseField CURVE = new ParseField("curve");
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
|
||||||
new ConstructingObjectParser<>("auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
|
|
||||||
|
|
||||||
static {
|
|
||||||
PARSER.declareDouble(constructorArg(), SCORE);
|
|
||||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
|
|
||||||
}
|
|
||||||
|
|
||||||
private final double score;
|
|
||||||
private final List<AucRocPoint> curve;
|
|
||||||
|
|
||||||
public Result(double score, @Nullable List<AucRocPoint> curve) {
|
|
||||||
this.score = score;
|
|
||||||
this.curve = curve;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getMetricName() {
|
|
||||||
return NAME;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getScore() {
|
|
||||||
return score;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<AucRocPoint> getCurve() {
|
|
||||||
return curve == null ? null : Collections.unmodifiableList(curve);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
|
||||||
builder.startObject();
|
|
||||||
builder.field(SCORE.getPreferredName(), score);
|
|
||||||
if (curve != null && curve.isEmpty() == false) {
|
|
||||||
builder.field(CURVE.getPreferredName(), curve);
|
|
||||||
}
|
|
||||||
builder.endObject();
|
|
||||||
return builder;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o) return true;
|
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
|
||||||
Result that = (Result) o;
|
|
||||||
return Objects.equals(score, that.score)
|
|
||||||
&& Objects.equals(curve, that.curve);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return Objects.hash(score, curve);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return Strings.toString(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static final class AucRocPoint implements ToXContentObject {
|
|
||||||
|
|
||||||
public static AucRocPoint fromXContent(XContentParser parser) {
|
|
||||||
return PARSER.apply(parser, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final ParseField TPR = new ParseField("tpr");
|
|
||||||
private static final ParseField FPR = new ParseField("fpr");
|
|
||||||
private static final ParseField THRESHOLD = new ParseField("threshold");
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
|
|
||||||
new ConstructingObjectParser<>(
|
|
||||||
"auc_roc_point",
|
|
||||||
true,
|
|
||||||
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
|
|
||||||
|
|
||||||
static {
|
|
||||||
PARSER.declareDouble(constructorArg(), TPR);
|
|
||||||
PARSER.declareDouble(constructorArg(), FPR);
|
|
||||||
PARSER.declareDouble(constructorArg(), THRESHOLD);
|
|
||||||
}
|
|
||||||
|
|
||||||
private final double tpr;
|
|
||||||
private final double fpr;
|
|
||||||
private final double threshold;
|
|
||||||
|
|
||||||
public AucRocPoint(double tpr, double fpr, double threshold) {
|
|
||||||
this.tpr = tpr;
|
|
||||||
this.fpr = fpr;
|
|
||||||
this.threshold = threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getTruePositiveRate() {
|
|
||||||
return tpr;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getFalsePositiveRate() {
|
|
||||||
return fpr;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getThreshold() {
|
|
||||||
return threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
||||||
return builder
|
|
||||||
.startObject()
|
|
||||||
.field(TPR.getPreferredName(), tpr)
|
|
||||||
.field(FPR.getPreferredName(), fpr)
|
|
||||||
.field(THRESHOLD.getPreferredName(), threshold)
|
|
||||||
.endObject();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (this == o) return true;
|
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
|
||||||
AucRocPoint that = (AucRocPoint) o;
|
|
||||||
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return Objects.hash(tpr, fpr, threshold);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return Strings.toString(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,13 +138,13 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
||||||
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
||||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
|
@ -1774,15 +1774,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
new OutlierDetection(
|
new OutlierDetection(
|
||||||
actualField,
|
actualField,
|
||||||
probabilityField,
|
probabilityField,
|
||||||
PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve()));
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
|
||||||
|
ConfusionMatrixMetric.at(0.5),
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
|
||||||
|
|
||||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME));
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME));
|
||||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
|
||||||
|
|
||||||
PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
|
||||||
assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
evaluateDataFrameResponse.getMetricByName(
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME);
|
||||||
|
assertThat(
|
||||||
|
precisionResult.getMetricName(),
|
||||||
|
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME));
|
||||||
// Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9)
|
// Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9)
|
||||||
assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9));
|
assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9));
|
||||||
// Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9)
|
// Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9)
|
||||||
|
@ -1791,8 +1798,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9));
|
assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9));
|
||||||
assertNull(precisionResult.getScoreByThreshold("0.1"));
|
assertNull(precisionResult.getScoreByThreshold("0.1"));
|
||||||
|
|
||||||
RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result recallResult =
|
||||||
assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME);
|
||||||
|
assertThat(
|
||||||
|
recallResult.getMetricName(),
|
||||||
|
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME));
|
||||||
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
||||||
assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9));
|
assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9));
|
||||||
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
||||||
|
@ -1808,7 +1818,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7
|
assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7
|
||||||
assertNull(confusionMatrixResult.getScoreByThreshold("0.1"));
|
assertNull(confusionMatrixResult.getScoreByThreshold("0.1"));
|
||||||
|
|
||||||
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
|
AucRocMetric.Result aucRocResult =
|
||||||
|
evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME);
|
||||||
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||||
assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9));
|
assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9));
|
||||||
assertNotNull(aucRocResult.getCurve());
|
assertNotNull(aucRocResult.getCurve());
|
||||||
|
@ -1920,24 +1931,40 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
createIndex(indexName, mappingForClassification());
|
createIndex(indexName, mappingForClassification());
|
||||||
BulkRequest regressionBulk = new BulkRequest()
|
BulkRequest regressionBulk = new BulkRequest()
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||||
.add(docForClassification(indexName, "cat", "cat"))
|
.add(docForClassification(indexName, "cat", "cat", 0.9))
|
||||||
.add(docForClassification(indexName, "cat", "cat"))
|
.add(docForClassification(indexName, "cat", "cat", 0.85))
|
||||||
.add(docForClassification(indexName, "cat", "cat"))
|
.add(docForClassification(indexName, "cat", "cat", 0.95))
|
||||||
.add(docForClassification(indexName, "cat", "dog"))
|
.add(docForClassification(indexName, "cat", "dog", 0.4))
|
||||||
.add(docForClassification(indexName, "cat", "fish"))
|
.add(docForClassification(indexName, "cat", "fish", 0.35))
|
||||||
.add(docForClassification(indexName, "dog", "cat"))
|
.add(docForClassification(indexName, "dog", "cat", 0.5))
|
||||||
.add(docForClassification(indexName, "dog", "dog"))
|
.add(docForClassification(indexName, "dog", "dog", 0.4))
|
||||||
.add(docForClassification(indexName, "dog", "dog"))
|
.add(docForClassification(indexName, "dog", "dog", 0.35))
|
||||||
.add(docForClassification(indexName, "dog", "dog"))
|
.add(docForClassification(indexName, "dog", "dog", 0.6))
|
||||||
.add(docForClassification(indexName, "ant", "cat"));
|
.add(docForClassification(indexName, "ant", "cat", 0.1));
|
||||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||||
|
|
||||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||||
|
|
||||||
|
{ // AucRoc
|
||||||
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||||
|
new EvaluateDataFrameRequest(
|
||||||
|
indexName, null, new Classification(actualClassField, null, topClassesField, AucRocMetric.forClassWithCurve("cat")));
|
||||||
|
|
||||||
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||||
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||||
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||||
|
|
||||||
|
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
|
||||||
|
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||||
|
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
|
||||||
|
assertThat(aucRocResult.getDocCount(), equalTo(5L));
|
||||||
|
assertNotNull(aucRocResult.getCurve());
|
||||||
|
}
|
||||||
{ // Accuracy
|
{ // Accuracy
|
||||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||||
new EvaluateDataFrameRequest(
|
new EvaluateDataFrameRequest(
|
||||||
indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
|
indexName, null, new Classification(actualClassField, predictedClassField, null, new AccuracyMetric()));
|
||||||
|
|
||||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||||
|
@ -1961,65 +1988,47 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
{ // Precision
|
{ // Precision
|
||||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||||
new EvaluateDataFrameRequest(
|
new EvaluateDataFrameRequest(
|
||||||
indexName,
|
indexName, null, new Classification(actualClassField, predictedClassField, null, new PrecisionMetric()));
|
||||||
null,
|
|
||||||
new Classification(
|
|
||||||
actualClassField,
|
|
||||||
predictedClassField,
|
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
|
|
||||||
|
|
||||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||||
|
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
|
PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
|
||||||
evaluateDataFrameResponse.getMetricByName(
|
assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
|
|
||||||
assertThat(
|
|
||||||
precisionResult.getMetricName(),
|
|
||||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
|
|
||||||
assertThat(
|
assertThat(
|
||||||
precisionResult.getClasses(),
|
precisionResult.getClasses(),
|
||||||
equalTo(
|
equalTo(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6),
|
new PrecisionMetric.PerClassResult("cat", 0.6),
|
||||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75))));
|
new PrecisionMetric.PerClassResult("dog", 0.75))));
|
||||||
assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
|
assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
|
||||||
}
|
}
|
||||||
{ // Recall
|
{ // Recall
|
||||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||||
new EvaluateDataFrameRequest(
|
new EvaluateDataFrameRequest(
|
||||||
indexName,
|
indexName, null, new Classification(actualClassField, predictedClassField, null, new RecallMetric()));
|
||||||
null,
|
|
||||||
new Classification(
|
|
||||||
actualClassField,
|
|
||||||
predictedClassField,
|
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
|
|
||||||
|
|
||||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||||
|
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
|
RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
|
||||||
evaluateDataFrameResponse.getMetricByName(
|
assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
|
|
||||||
assertThat(
|
|
||||||
recallResult.getMetricName(),
|
|
||||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
|
|
||||||
assertThat(
|
assertThat(
|
||||||
recallResult.getClasses(),
|
recallResult.getClasses(),
|
||||||
equalTo(
|
equalTo(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6),
|
new RecallMetric.PerClassResult("cat", 0.6),
|
||||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75),
|
new RecallMetric.PerClassResult("dog", 0.75),
|
||||||
// no examples labeled as "ant" were classified correctly
|
// no examples labeled as "ant" were classified correctly
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0))));
|
new RecallMetric.PerClassResult("ant", 0.0))));
|
||||||
assertThat(recallResult.getAvgRecall(), equalTo(0.45));
|
assertThat(recallResult.getAvgRecall(), equalTo(0.45));
|
||||||
}
|
}
|
||||||
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
||||||
|
@ -2027,7 +2036,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
new EvaluateDataFrameRequest(
|
new EvaluateDataFrameRequest(
|
||||||
indexName,
|
indexName,
|
||||||
null,
|
null,
|
||||||
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
|
new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric()));
|
||||||
|
|
||||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||||
|
@ -2072,7 +2081,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
new EvaluateDataFrameRequest(
|
new EvaluateDataFrameRequest(
|
||||||
indexName,
|
indexName,
|
||||||
null,
|
null,
|
||||||
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
|
new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric(2)));
|
||||||
|
|
||||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||||
|
@ -2146,6 +2155,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
|
|
||||||
private static final String actualClassField = "actual_class";
|
private static final String actualClassField = "actual_class";
|
||||||
private static final String predictedClassField = "predicted_class";
|
private static final String predictedClassField = "predicted_class";
|
||||||
|
private static final String topClassesField = "top_classes";
|
||||||
|
|
||||||
private static XContentBuilder mappingForClassification() throws IOException {
|
private static XContentBuilder mappingForClassification() throws IOException {
|
||||||
return XContentFactory.jsonBuilder().startObject()
|
return XContentFactory.jsonBuilder().startObject()
|
||||||
|
@ -2156,14 +2166,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
.startObject(predictedClassField)
|
.startObject(predictedClassField)
|
||||||
.field("type", "keyword")
|
.field("type", "keyword")
|
||||||
.endObject()
|
.endObject()
|
||||||
|
.startObject(topClassesField)
|
||||||
|
.field("type", "nested")
|
||||||
|
.endObject()
|
||||||
.endObject()
|
.endObject()
|
||||||
.endObject();
|
.endObject();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
|
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
|
||||||
return new IndexRequest()
|
return new IndexRequest()
|
||||||
.index(indexName)
|
.index(indexName)
|
||||||
.source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
|
.source(XContentType.JSON,
|
||||||
|
actualClassField, actualClass,
|
||||||
|
predictedClassField, predictedClass,
|
||||||
|
topClassesField, Arrays.asList(
|
||||||
|
new HashMap<String, Object>() {{
|
||||||
|
put("class_name", predictedClass);
|
||||||
|
put("class_probability", p);
|
||||||
|
}},
|
||||||
|
new HashMap<String, Object>() {{
|
||||||
|
put("class_name", "other");
|
||||||
|
put("class_probability", 1 - p);
|
||||||
|
}}));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final String actualRegression = "regression_actual";
|
private static final String actualRegression = "regression_actual";
|
||||||
|
|
|
@ -59,13 +59,13 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction;
|
||||||
import org.elasticsearch.client.indexlifecycle.WaitForSnapshotAction;
|
import org.elasticsearch.client.indexlifecycle.WaitForSnapshotAction;
|
||||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
|
@ -707,7 +707,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||||
|
|
||||||
public void testProvidedNamedXContents() {
|
public void testProvidedNamedXContents() {
|
||||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||||
assertEquals(73, namedXContents.size());
|
assertEquals(75, namedXContents.size());
|
||||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||||
List<String> names = new ArrayList<>();
|
List<String> names = new ArrayList<>();
|
||||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||||
|
@ -756,35 +756,39 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||||
assertThat(names, hasItems(OutlierDetection.NAME, Classification.NAME, Regression.NAME));
|
assertThat(names, hasItems(OutlierDetection.NAME, Classification.NAME, Regression.NAME));
|
||||||
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||||
assertThat(names,
|
assertThat(names,
|
||||||
hasItems(
|
hasItems(
|
||||||
registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME),
|
registeredMetricName(
|
||||||
registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME),
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME),
|
||||||
registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME),
|
||||||
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME),
|
||||||
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
|
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
|
||||||
|
registeredMetricName(Classification.NAME, AucRocMetric.NAME),
|
||||||
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
||||||
registeredMetricName(
|
registeredMetricName(Classification.NAME, PrecisionMetric.NAME),
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
|
registeredMetricName(Classification.NAME, RecallMetric.NAME),
|
||||||
registeredMetricName(
|
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
|
||||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, HuberMetric.NAME),
|
registeredMetricName(Regression.NAME, HuberMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||||
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||||
assertThat(names,
|
assertThat(names,
|
||||||
hasItems(
|
hasItems(
|
||||||
registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME),
|
registeredMetricName(
|
||||||
registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME),
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME),
|
||||||
registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME),
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME),
|
||||||
|
registeredMetricName(
|
||||||
|
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME),
|
||||||
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
|
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
|
||||||
|
registeredMetricName(Classification.NAME, AucRocMetric.NAME),
|
||||||
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
||||||
registeredMetricName(
|
registeredMetricName(Classification.NAME, PrecisionMetric.NAME),
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
|
registeredMetricName(Classification.NAME, RecallMetric.NAME),
|
||||||
registeredMetricName(
|
|
||||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
|
||||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||||
|
|
|
@ -156,15 +156,15 @@ import org.elasticsearch.client.ml.dataframe.Regression;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
|
@ -201,6 +201,7 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition;
|
||||||
import org.elasticsearch.client.ml.job.results.Influencer;
|
import org.elasticsearch.client.ml.job.results.Influencer;
|
||||||
import org.elasticsearch.client.ml.job.results.OverallBucket;
|
import org.elasticsearch.client.ml.job.results.OverallBucket;
|
||||||
import org.elasticsearch.client.ml.job.stats.JobStats;
|
import org.elasticsearch.client.ml.job.stats.JobStats;
|
||||||
|
import org.elasticsearch.common.TriFunction;
|
||||||
import org.elasticsearch.common.bytes.BytesReference;
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||||
|
@ -3326,7 +3327,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
30, TimeUnit.SECONDS);
|
30, TimeUnit.SECONDS);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluateDataFrame() throws Exception {
|
public void testEvaluateDataFrame_OutlierDetection() throws Exception {
|
||||||
String indexName = "evaluate-test-index";
|
String indexName = "evaluate-test-index";
|
||||||
CreateIndexRequest createIndexRequest =
|
CreateIndexRequest createIndexRequest =
|
||||||
new CreateIndexRequest(indexName)
|
new CreateIndexRequest(indexName)
|
||||||
|
@ -3363,10 +3364,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
"label", // <2>
|
"label", // <2>
|
||||||
"p", // <3>
|
"p", // <3>
|
||||||
// Evaluation metrics // <4>
|
// Evaluation metrics // <4>
|
||||||
PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
|
||||||
RecallMetric.at(0.5, 0.7), // <6>
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), // <6>
|
||||||
ConfusionMatrixMetric.at(0.5), // <7>
|
ConfusionMatrixMetric.at(0.5), // <7>
|
||||||
AucRocMetric.withCurve()); // <8>
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()); // <8>
|
||||||
// end::evaluate-data-frame-evaluation-outlierdetection
|
// end::evaluate-data-frame-evaluation-outlierdetection
|
||||||
|
|
||||||
// tag::evaluate-data-frame-request
|
// tag::evaluate-data-frame-request
|
||||||
|
@ -3386,7 +3387,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
// end::evaluate-data-frame-response
|
// end::evaluate-data-frame-response
|
||||||
|
|
||||||
// tag::evaluate-data-frame-results-outlierdetection
|
// tag::evaluate-data-frame-results-outlierdetection
|
||||||
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1>
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
|
||||||
|
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME); // <1>
|
||||||
double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
|
double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
|
||||||
|
|
||||||
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
|
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
|
||||||
|
@ -3395,7 +3397,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
|
|
||||||
assertThat(
|
assertThat(
|
||||||
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
|
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
|
||||||
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
|
containsInAnyOrder(
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME,
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME,
|
||||||
|
ConfusionMatrixMetric.NAME,
|
||||||
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME));
|
||||||
assertThat(precision, closeTo(0.6, 1e-9));
|
assertThat(precision, closeTo(0.6, 1e-9));
|
||||||
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
|
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
|
||||||
assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4
|
assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4
|
||||||
|
@ -3409,10 +3415,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
new OutlierDetection(
|
new OutlierDetection(
|
||||||
"label",
|
"label",
|
||||||
"p",
|
"p",
|
||||||
PrecisionMetric.at(0.4, 0.5, 0.6),
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
|
||||||
RecallMetric.at(0.5, 0.7),
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
|
||||||
ConfusionMatrixMetric.at(0.5),
|
ConfusionMatrixMetric.at(0.5),
|
||||||
AucRocMetric.withCurve()));
|
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
|
||||||
|
|
||||||
// tag::evaluate-data-frame-execute-listener
|
// tag::evaluate-data-frame-execute-listener
|
||||||
ActionListener<EvaluateDataFrameResponse> listener = new ActionListener<EvaluateDataFrameResponse>() {
|
ActionListener<EvaluateDataFrameResponse> listener = new ActionListener<EvaluateDataFrameResponse>() {
|
||||||
|
@ -3452,21 +3458,39 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
.startObject("predicted_class")
|
.startObject("predicted_class")
|
||||||
.field("type", "keyword")
|
.field("type", "keyword")
|
||||||
.endObject()
|
.endObject()
|
||||||
|
.startObject("ml.top_classes")
|
||||||
|
.field("type", "nested")
|
||||||
|
.endObject()
|
||||||
.endObject()
|
.endObject()
|
||||||
.endObject());
|
.endObject());
|
||||||
|
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
|
||||||
|
return new IndexRequest()
|
||||||
|
.source(XContentType.JSON,
|
||||||
|
"actual_class", actualClass,
|
||||||
|
"predicted_class", predictedClass,
|
||||||
|
"ml.top_classes", Arrays.asList(
|
||||||
|
new HashMap<String, Object>() {{
|
||||||
|
put("class_name", predictedClass);
|
||||||
|
put("class_probability", p);
|
||||||
|
}},
|
||||||
|
new HashMap<String, Object>() {{
|
||||||
|
put("class_name", "other");
|
||||||
|
put("class_probability", 1 - p);
|
||||||
|
}}));
|
||||||
|
};
|
||||||
BulkRequest bulkRequest =
|
BulkRequest bulkRequest =
|
||||||
new BulkRequest(indexName)
|
new BulkRequest(indexName)
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0
|
.add(indexRequest.apply("cat", "cat", 0.9)) // #0
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1
|
.add(indexRequest.apply("cat", "cat", 0.9)) // #1
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2
|
.add(indexRequest.apply("cat", "cat", 0.9)) // #2
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3
|
.add(indexRequest.apply("cat", "dog", 0.9)) // #3
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4
|
.add(indexRequest.apply("cat", "fox", 0.9)) // #4
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5
|
.add(indexRequest.apply("dog", "cat", 0.9)) // #5
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6
|
.add(indexRequest.apply("dog", "dog", 0.9)) // #6
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7
|
.add(indexRequest.apply("dog", "dog", 0.9)) // #7
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8
|
.add(indexRequest.apply("dog", "dog", 0.9)) // #8
|
||||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9
|
.add(indexRequest.apply("ant", "cat", 0.9)); // #9
|
||||||
RestHighLevelClient client = highLevelClient();
|
RestHighLevelClient client = highLevelClient();
|
||||||
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||||
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
||||||
|
@ -3476,11 +3500,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
|
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
|
||||||
"actual_class", // <2>
|
"actual_class", // <2>
|
||||||
"predicted_class", // <3>
|
"predicted_class", // <3>
|
||||||
// Evaluation metrics // <4>
|
"ml.top_classes", // <4>
|
||||||
new AccuracyMetric(), // <5>
|
// Evaluation metrics // <5>
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6>
|
new AccuracyMetric(), // <6>
|
||||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7>
|
new PrecisionMetric(), // <7>
|
||||||
new MulticlassConfusionMatrixMetric(3)); // <8>
|
new RecallMetric(), // <8>
|
||||||
|
new MulticlassConfusionMatrixMetric(3), // <9>
|
||||||
|
AucRocMetric.forClass("cat")); // <10>
|
||||||
// end::evaluate-data-frame-evaluation-classification
|
// end::evaluate-data-frame-evaluation-classification
|
||||||
|
|
||||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
||||||
|
@ -3490,12 +3516,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
|
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
|
||||||
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
|
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
|
||||||
|
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
|
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <3>
|
||||||
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3>
|
|
||||||
double precision = precisionResult.getAvgPrecision(); // <4>
|
double precision = precisionResult.getAvgPrecision(); // <4>
|
||||||
|
|
||||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
|
RecallMetric.Result recallResult = response.getMetricByName(RecallMetric.NAME); // <5>
|
||||||
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5>
|
|
||||||
double recall = recallResult.getAvgRecall(); // <6>
|
double recall = recallResult.getAvgRecall(); // <6>
|
||||||
|
|
||||||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||||
|
@ -3503,19 +3527,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
|
|
||||||
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
|
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
|
||||||
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
|
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
|
||||||
|
|
||||||
|
AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
|
||||||
|
double aucRocScore = aucRocResult.getScore(); // <11>
|
||||||
|
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
|
||||||
// end::evaluate-data-frame-results-classification
|
// end::evaluate-data-frame-results-classification
|
||||||
|
|
||||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||||
assertThat(accuracy, equalTo(0.6));
|
assertThat(accuracy, equalTo(0.6));
|
||||||
|
|
||||||
assertThat(
|
assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
||||||
precisionResult.getMetricName(),
|
|
||||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
|
|
||||||
assertThat(precision, equalTo(0.675));
|
assertThat(precision, equalTo(0.675));
|
||||||
|
|
||||||
assertThat(
|
assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
||||||
recallResult.getMetricName(),
|
|
||||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
|
|
||||||
assertThat(recall, equalTo(0.45));
|
assertThat(recall, equalTo(0.45));
|
||||||
|
|
||||||
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||||
|
@ -3539,6 +3563,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
|
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
|
||||||
0L))));
|
0L))));
|
||||||
assertThat(otherClassesCount, equalTo(0L));
|
assertThat(otherClassesCount, equalTo(0L));
|
||||||
|
|
||||||
|
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||||
|
assertThat(aucRocScore, equalTo(0.2625));
|
||||||
|
assertThat(aucRocDocCount, equalTo(5L));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetricResultTests;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetricResultTests;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests;
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
|
@ -16,7 +16,7 @@
|
||||||
* specific language governing permissions and limitations
|
* specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
@ -31,6 +31,7 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
|
||||||
public static AucRocMetric.Result randomResult() {
|
public static AucRocMetric.Result randomResult() {
|
||||||
return new AucRocMetric.Result(
|
return new AucRocMetric.Result(
|
||||||
randomDouble(),
|
randomDouble(),
|
||||||
|
randomLong(),
|
||||||
Stream
|
Stream
|
||||||
.generate(AucRocMetricAucRocPointTests::randomPoint)
|
.generate(AucRocMetricAucRocPointTests::randomPoint)
|
||||||
.limit(randomIntBetween(1, 10))
|
.limit(randomIntBetween(1, 10))
|
|
@ -0,0 +1,55 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AucRocMetric createRandom() {
|
||||||
|
return new AucRocMetric(
|
||||||
|
randomAlphaOfLengthBetween(1, 10),
|
||||||
|
randomBoolean() ? randomBoolean() : null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AucRocMetric createTestInstance() {
|
||||||
|
return createRandom();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return AucRocMetric.fromXContent(parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,11 +40,16 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
||||||
List<EvaluationMetric> metrics =
|
List<EvaluationMetric> metrics =
|
||||||
randomSubsetOf(
|
randomSubsetOf(
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
AucRocMetricTests.createRandom(),
|
||||||
AccuracyMetricTests.createRandom(),
|
AccuracyMetricTests.createRandom(),
|
||||||
PrecisionMetricTests.createRandom(),
|
PrecisionMetricTests.createRandom(),
|
||||||
RecallMetricTests.createRandom(),
|
RecallMetricTests.createRandom(),
|
||||||
MulticlassConfusionMatrixMetricTests.createRandom()));
|
MulticlassConfusionMatrixMetricTests.createRandom()));
|
||||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
return new Classification(
|
||||||
|
randomAlphaOfLength(10),
|
||||||
|
randomBoolean() ? randomAlphaOfLength(10) : null,
|
||||||
|
randomBoolean() ? randomAlphaOfLength(10) : null,
|
||||||
|
metrics.isEmpty() ? null : metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AucRocMetric createRandom() {
|
||||||
|
return new AucRocMetric(randomBoolean() ? randomBoolean() : null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AucRocMetric createTestInstance() {
|
||||||
|
return createRandom();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return AucRocMetric.fromXContent(parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,7 +40,7 @@ public class OutlierDetectionTests extends AbstractXContentTestCase<OutlierDetec
|
||||||
public static OutlierDetection createRandom() {
|
public static OutlierDetection createRandom() {
|
||||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
metrics.add(new AucRocMetric(randomBoolean()));
|
metrics.add(AucRocMetricTests.createRandom());
|
||||||
}
|
}
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1,
|
metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1,
|
||||||
|
|
|
@ -51,11 +51,13 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
|
||||||
<1> Constructing a new evaluation
|
<1> Constructing a new evaluation
|
||||||
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
|
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
|
||||||
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
|
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
|
||||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
<4> Name of the field in the index. Its value denotes the array of top classes. Must be nested.
|
||||||
<5> Accuracy
|
<5> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||||
<6> Precision
|
<6> Accuracy
|
||||||
<7> Recall
|
<7> Precision
|
||||||
<8> Multiclass confusion matrix of size 3
|
<8> Recall
|
||||||
|
<9> Multiclass confusion matrix of size 3
|
||||||
|
<10> {wikipedia}/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated for class "cat" treated as positive and the rest as negative
|
||||||
|
|
||||||
===== Regression
|
===== Regression
|
||||||
|
|
||||||
|
@ -115,6 +117,9 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
|
||||||
<7> Fetching multiclass confusion matrix metric by name
|
<7> Fetching multiclass confusion matrix metric by name
|
||||||
<8> Fetching the contents of the confusion matrix
|
<8> Fetching the contents of the confusion matrix
|
||||||
<9> Fetching the number of classes that were not included in the matrix
|
<9> Fetching the number of classes that were not included in the matrix
|
||||||
|
<10> Fetching AucRoc metric by name
|
||||||
|
<11> Fetching the actual AucRoc score
|
||||||
|
<12> Fetching the number of documents that were used in order to calculate AucRoc score
|
||||||
|
|
||||||
===== Regression
|
===== Regression
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue