This commit is contained in:
parent
0860746bf2
commit
d677a2b8ee
|
@ -19,13 +19,13 @@
|
|||
package org.elasticsearch.client.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
|
@ -63,34 +63,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
// Evaluation metrics
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)),
|
||||
AucRocMetric::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)),
|
||||
PrecisionMetric::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)),
|
||||
RecallMetric::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
|
||||
ConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
|
||||
AucRocMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
||||
AccuracyMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
|
||||
new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
|
||||
PrecisionMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
|
||||
new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
|
||||
RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
||||
|
@ -114,34 +122,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
// Evaluation metrics results
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)),
|
||||
AucRocMetric.Result::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)),
|
||||
PrecisionMetric.Result::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)),
|
||||
RecallMetric.Result::fromXContent),
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
|
||||
ConfusionMatrixMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
|
||||
AucRocMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
||||
AccuracyMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
|
||||
new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
|
||||
PrecisionMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
|
||||
new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
|
||||
RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
||||
|
|
|
@ -0,0 +1,264 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
|
||||
* The ROC curve is a plot of the TPR (true positive rate) against
|
||||
* the FPR (false positive rate) over a varying threshold.
|
||||
*/
|
||||
public class AucRocMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "auc_roc";
|
||||
|
||||
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||
PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
|
||||
}
|
||||
|
||||
public static AucRocMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public static AucRocMetric forClass(String className) {
|
||||
return new AucRocMetric(className, false);
|
||||
}
|
||||
|
||||
public static AucRocMetric forClassWithCurve(String className) {
|
||||
return new AucRocMetric(className, true);
|
||||
}
|
||||
|
||||
private final String className;
|
||||
private final Boolean includeCurve;
|
||||
|
||||
public AucRocMetric(String className, Boolean includeCurve) {
|
||||
this.className = Objects.requireNonNull(className);
|
||||
this.includeCurve = includeCurve;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
if (includeCurve != null) {
|
||||
builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AucRocMetric that = (AucRocMetric) o;
|
||||
return Objects.equals(className, that.className)
|
||||
&& Objects.equals(includeCurve, that.includeCurve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, includeCurve);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ParseField SCORE = new ParseField("score");
|
||||
private static final ParseField DOC_COUNT = new ParseField("doc_count");
|
||||
private static final ParseField CURVE = new ParseField("curve");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), SCORE);
|
||||
PARSER.declareLong(constructorArg(), DOC_COUNT);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
|
||||
}
|
||||
|
||||
private final double score;
|
||||
private final long docCount;
|
||||
private final List<AucRocPoint> curve;
|
||||
|
||||
public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
|
||||
this.score = score;
|
||||
this.docCount = docCount;
|
||||
this.curve = curve;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public long getDocCount() {
|
||||
return docCount;
|
||||
}
|
||||
|
||||
public List<AucRocPoint> getCurve() {
|
||||
return curve == null ? null : Collections.unmodifiableList(curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(SCORE.getPreferredName(), score);
|
||||
builder.field(DOC_COUNT.getPreferredName(), docCount);
|
||||
if (curve != null && curve.isEmpty() == false) {
|
||||
builder.field(CURVE.getPreferredName(), curve);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return score == that.score
|
||||
&& docCount == that.docCount
|
||||
&& Objects.equals(curve, that.curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(score, docCount, curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class AucRocPoint implements ToXContentObject {
|
||||
|
||||
public static AucRocPoint fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ParseField TPR = new ParseField("tpr");
|
||||
private static final ParseField FPR = new ParseField("fpr");
|
||||
private static final ParseField THRESHOLD = new ParseField("threshold");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"auc_roc_point",
|
||||
true,
|
||||
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), TPR);
|
||||
PARSER.declareDouble(constructorArg(), FPR);
|
||||
PARSER.declareDouble(constructorArg(), THRESHOLD);
|
||||
}
|
||||
|
||||
private final double tpr;
|
||||
private final double fpr;
|
||||
private final double threshold;
|
||||
|
||||
public AucRocPoint(double tpr, double fpr, double threshold) {
|
||||
this.tpr = tpr;
|
||||
this.fpr = fpr;
|
||||
this.threshold = threshold;
|
||||
}
|
||||
|
||||
public double getTruePositiveRate() {
|
||||
return tpr;
|
||||
}
|
||||
|
||||
public double getFalsePositiveRate() {
|
||||
return fpr;
|
||||
}
|
||||
|
||||
public double getThreshold() {
|
||||
return threshold;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return builder
|
||||
.startObject()
|
||||
.field(TPR.getPreferredName(), tpr)
|
||||
.field(FPR.getPreferredName(), fpr)
|
||||
.field(THRESHOLD.getPreferredName(), threshold)
|
||||
.endObject();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AucRocPoint that = (AucRocPoint) o;
|
||||
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(tpr, fpr, threshold);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -45,15 +45,20 @@ public class Classification implements Evaluation {
|
|||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
private static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
|
||||
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
NAME,
|
||||
true,
|
||||
a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List<EvaluationMetric>) a[3]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareString(optionalConstructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_FIELD);
|
||||
PARSER.declareNamedObjects(
|
||||
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
|
||||
}
|
||||
|
@ -64,32 +69,44 @@ public class Classification implements Evaluation {
|
|||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String predictedField;
|
||||
|
||||
/**
|
||||
* The field containing the array of top classes
|
||||
*/
|
||||
private final String topClassesField;
|
||||
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public Classification(String actualField, String predictedField) {
|
||||
this(actualField, predictedField, (List<EvaluationMetric>)null);
|
||||
public Classification(String actualField,
|
||||
String predictedField,
|
||||
String topClassesField) {
|
||||
this(actualField, predictedField, topClassesField, (List<EvaluationMetric>)null);
|
||||
}
|
||||
|
||||
public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
|
||||
this(actualField, predictedField, Arrays.asList(metrics));
|
||||
public Classification(String actualField,
|
||||
String predictedField,
|
||||
String topClassesField,
|
||||
EvaluationMetric... metrics) {
|
||||
this(actualField, predictedField, topClassesField, Arrays.asList(metrics));
|
||||
}
|
||||
|
||||
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
||||
public Classification(String actualField,
|
||||
@Nullable String predictedField,
|
||||
@Nullable String topClassesField,
|
||||
@Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = Objects.requireNonNull(actualField);
|
||||
this.predictedField = Objects.requireNonNull(predictedField);
|
||||
this.predictedField = predictedField;
|
||||
this.topClassesField = topClassesField;
|
||||
if (metrics != null) {
|
||||
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
|
||||
}
|
||||
|
@ -105,8 +122,12 @@ public class Classification implements Evaluation {
|
|||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
if (predictedField != null) {
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
}
|
||||
if (topClassesField != null) {
|
||||
builder.field(TOP_CLASSES_FIELD.getPreferredName(), topClassesField);
|
||||
}
|
||||
if (metrics != null) {
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
|
@ -126,11 +147,12 @@ public class Classification implements Evaluation {
|
|||
Classification that = (Classification) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
&& Objects.equals(that.topClassesField, this.topClassesField)
|
||||
&& Objects.equals(that.metrics, this.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedField, metrics);
|
||||
return Objects.hash(actualField, predictedField, topClassesField, metrics);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,21 +19,14 @@
|
|||
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
|
@ -49,7 +42,7 @@ public class AucRocMetric implements EvaluationMetric {
|
|||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0]));
|
||||
new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((Boolean) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
|
||||
|
@ -63,18 +56,20 @@ public class AucRocMetric implements EvaluationMetric {
|
|||
return new AucRocMetric(true);
|
||||
}
|
||||
|
||||
private final boolean includeCurve;
|
||||
private final Boolean includeCurve;
|
||||
|
||||
public AucRocMetric(Boolean includeCurve) {
|
||||
this.includeCurve = includeCurve == null ? false : includeCurve;
|
||||
this.includeCurve = includeCurve;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
return builder
|
||||
.startObject()
|
||||
.field(INCLUDE_CURVE.getPreferredName(), includeCurve)
|
||||
.endObject();
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (includeCurve != null) {
|
||||
builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -94,148 +89,4 @@ public class AucRocMetric implements EvaluationMetric {
|
|||
public int hashCode() {
|
||||
return Objects.hash(includeCurve);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ParseField SCORE = new ParseField("score");
|
||||
private static final ParseField CURVE = new ParseField("curve");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), SCORE);
|
||||
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
|
||||
}
|
||||
|
||||
private final double score;
|
||||
private final List<AucRocPoint> curve;
|
||||
|
||||
public Result(double score, @Nullable List<AucRocPoint> curve) {
|
||||
this.score = score;
|
||||
this.curve = curve;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public List<AucRocPoint> getCurve() {
|
||||
return curve == null ? null : Collections.unmodifiableList(curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(SCORE.getPreferredName(), score);
|
||||
if (curve != null && curve.isEmpty() == false) {
|
||||
builder.field(CURVE.getPreferredName(), curve);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(score, that.score)
|
||||
&& Objects.equals(curve, that.curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(score, curve);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class AucRocPoint implements ToXContentObject {
|
||||
|
||||
public static AucRocPoint fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ParseField TPR = new ParseField("tpr");
|
||||
private static final ParseField FPR = new ParseField("fpr");
|
||||
private static final ParseField THRESHOLD = new ParseField("threshold");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"auc_roc_point",
|
||||
true,
|
||||
args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), TPR);
|
||||
PARSER.declareDouble(constructorArg(), FPR);
|
||||
PARSER.declareDouble(constructorArg(), THRESHOLD);
|
||||
}
|
||||
|
||||
private final double tpr;
|
||||
private final double fpr;
|
||||
private final double threshold;
|
||||
|
||||
public AucRocPoint(double tpr, double fpr, double threshold) {
|
||||
this.tpr = tpr;
|
||||
this.fpr = fpr;
|
||||
this.threshold = threshold;
|
||||
}
|
||||
|
||||
public double getTruePositiveRate() {
|
||||
return tpr;
|
||||
}
|
||||
|
||||
public double getFalsePositiveRate() {
|
||||
return fpr;
|
||||
}
|
||||
|
||||
public double getThreshold() {
|
||||
return threshold;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
return builder
|
||||
.startObject()
|
||||
.field(TPR.getPreferredName(), tpr)
|
||||
.field(FPR.getPreferredName(), fpr)
|
||||
.field(THRESHOLD.getPreferredName(), threshold)
|
||||
.endObject();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
AucRocPoint that = (AucRocPoint) o;
|
||||
return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(tpr, fpr, threshold);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Strings.toString(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,13 +138,13 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
|||
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
|
@ -1774,15 +1774,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
new OutlierDetection(
|
||||
actualField,
|
||||
probabilityField,
|
||||
PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve()));
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
|
||||
ConfusionMatrixMetric.at(0.5),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
|
||||
|
||||
PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
|
||||
evaluateDataFrameResponse.getMetricByName(
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME);
|
||||
assertThat(
|
||||
precisionResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME));
|
||||
// Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9)
|
||||
assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9));
|
||||
// Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9)
|
||||
|
@ -1791,8 +1798,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9));
|
||||
assertNull(precisionResult.getScoreByThreshold("0.1"));
|
||||
|
||||
RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
|
||||
assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result recallResult =
|
||||
evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME);
|
||||
assertThat(
|
||||
recallResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME));
|
||||
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
||||
assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9));
|
||||
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
||||
|
@ -1808,7 +1818,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7
|
||||
assertNull(confusionMatrixResult.getScoreByThreshold("0.1"));
|
||||
|
||||
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
|
||||
AucRocMetric.Result aucRocResult =
|
||||
evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME);
|
||||
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||
assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9));
|
||||
assertNotNull(aucRocResult.getCurve());
|
||||
|
@ -1920,24 +1931,40 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
createIndex(indexName, mappingForClassification());
|
||||
BulkRequest regressionBulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForClassification(indexName, "cat", "cat"))
|
||||
.add(docForClassification(indexName, "cat", "cat"))
|
||||
.add(docForClassification(indexName, "cat", "cat"))
|
||||
.add(docForClassification(indexName, "cat", "dog"))
|
||||
.add(docForClassification(indexName, "cat", "fish"))
|
||||
.add(docForClassification(indexName, "dog", "cat"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "dog", "dog"))
|
||||
.add(docForClassification(indexName, "ant", "cat"));
|
||||
.add(docForClassification(indexName, "cat", "cat", 0.9))
|
||||
.add(docForClassification(indexName, "cat", "cat", 0.85))
|
||||
.add(docForClassification(indexName, "cat", "cat", 0.95))
|
||||
.add(docForClassification(indexName, "cat", "dog", 0.4))
|
||||
.add(docForClassification(indexName, "cat", "fish", 0.35))
|
||||
.add(docForClassification(indexName, "dog", "cat", 0.5))
|
||||
.add(docForClassification(indexName, "dog", "dog", 0.4))
|
||||
.add(docForClassification(indexName, "dog", "dog", 0.35))
|
||||
.add(docForClassification(indexName, "dog", "dog", 0.6))
|
||||
.add(docForClassification(indexName, "ant", "cat", 0.1));
|
||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||
|
||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||
|
||||
{ // AucRoc
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName, null, new Classification(actualClassField, null, topClassesField, AucRocMetric.forClassWithCurve("cat")));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
|
||||
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||
assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
|
||||
assertThat(aucRocResult.getDocCount(), equalTo(5L));
|
||||
assertNotNull(aucRocResult.getCurve());
|
||||
}
|
||||
{ // Accuracy
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
|
||||
indexName, null, new Classification(actualClassField, predictedClassField, null, new AccuracyMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
|
@ -1961,65 +1988,47 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
{ // Precision
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(
|
||||
actualClassField,
|
||||
predictedClassField,
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
|
||||
indexName, null, new Classification(actualClassField, predictedClassField, null, new PrecisionMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
|
||||
evaluateDataFrameResponse.getMetricByName(
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
|
||||
assertThat(
|
||||
precisionResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
|
||||
PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6),
|
||||
new PrecisionMetric.PerClassResult("cat", 0.6),
|
||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75))));
|
||||
new PrecisionMetric.PerClassResult("dog", 0.75))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
|
||||
}
|
||||
{ // Recall
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(
|
||||
actualClassField,
|
||||
predictedClassField,
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
|
||||
indexName, null, new Classification(actualClassField, predictedClassField, null, new RecallMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
|
||||
evaluateDataFrameResponse.getMetricByName(
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
|
||||
assertThat(
|
||||
recallResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
|
||||
RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
|
||||
assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6),
|
||||
new RecallMetric.PerClassResult("cat", 0.6),
|
||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75),
|
||||
new RecallMetric.PerClassResult("dog", 0.75),
|
||||
// no examples labeled as "ant" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0))));
|
||||
new RecallMetric.PerClassResult("ant", 0.0))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(0.45));
|
||||
}
|
||||
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
||||
|
@ -2027,7 +2036,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
|
||||
new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
|
@ -2072,7 +2081,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
|
||||
new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric(2)));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
|
@ -2146,6 +2155,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
private static final String actualClassField = "actual_class";
|
||||
private static final String predictedClassField = "predicted_class";
|
||||
private static final String topClassesField = "top_classes";
|
||||
|
||||
private static XContentBuilder mappingForClassification() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
|
@ -2156,14 +2166,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.startObject(predictedClassField)
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.startObject(topClassesField)
|
||||
.field("type", "nested")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
|
||||
private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
|
||||
return new IndexRequest()
|
||||
.index(indexName)
|
||||
.source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
|
||||
.source(XContentType.JSON,
|
||||
actualClassField, actualClass,
|
||||
predictedClassField, predictedClass,
|
||||
topClassesField, Arrays.asList(
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", predictedClass);
|
||||
put("class_probability", p);
|
||||
}},
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", "other");
|
||||
put("class_probability", 1 - p);
|
||||
}}));
|
||||
}
|
||||
|
||||
private static final String actualRegression = "regression_actual";
|
||||
|
|
|
@ -59,13 +59,13 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction;
|
|||
import org.elasticsearch.client.indexlifecycle.WaitForSnapshotAction;
|
||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
|
@ -707,7 +707,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(73, namedXContents.size());
|
||||
assertEquals(75, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -756,35 +756,39 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(OutlierDetection.NAME, Classification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names,
|
||||
hasItems(
|
||||
registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME),
|
||||
registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME),
|
||||
registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME),
|
||||
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, AucRocMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, PrecisionMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, HuberMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names,
|
||||
hasItems(
|
||||
registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME),
|
||||
registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME),
|
||||
registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME),
|
||||
registeredMetricName(
|
||||
OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME),
|
||||
registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, AucRocMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, PrecisionMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||
|
|
|
@ -156,15 +156,15 @@ import org.elasticsearch.client.ml.dataframe.Regression;
|
|||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
|
@ -201,6 +201,7 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition;
|
|||
import org.elasticsearch.client.ml.job.results.Influencer;
|
||||
import org.elasticsearch.client.ml.job.results.OverallBucket;
|
||||
import org.elasticsearch.client.ml.job.stats.JobStats;
|
||||
import org.elasticsearch.common.TriFunction;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
|
@ -3326,7 +3327,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
30, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
public void testEvaluateDataFrame() throws Exception {
|
||||
public void testEvaluateDataFrame_OutlierDetection() throws Exception {
|
||||
String indexName = "evaluate-test-index";
|
||||
CreateIndexRequest createIndexRequest =
|
||||
new CreateIndexRequest(indexName)
|
||||
|
@ -3363,10 +3364,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
"label", // <2>
|
||||
"p", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
|
||||
RecallMetric.at(0.5, 0.7), // <6>
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), // <6>
|
||||
ConfusionMatrixMetric.at(0.5), // <7>
|
||||
AucRocMetric.withCurve()); // <8>
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()); // <8>
|
||||
// end::evaluate-data-frame-evaluation-outlierdetection
|
||||
|
||||
// tag::evaluate-data-frame-request
|
||||
|
@ -3386,7 +3387,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
// end::evaluate-data-frame-response
|
||||
|
||||
// tag::evaluate-data-frame-results-outlierdetection
|
||||
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1>
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
|
||||
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME); // <1>
|
||||
double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
|
||||
|
||||
ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
|
||||
|
@ -3395,7 +3397,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
assertThat(
|
||||
metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
|
||||
containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
|
||||
containsInAnyOrder(
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME,
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME));
|
||||
assertThat(precision, closeTo(0.6, 1e-9));
|
||||
assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9
|
||||
assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4
|
||||
|
@ -3409,10 +3415,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
new OutlierDetection(
|
||||
"label",
|
||||
"p",
|
||||
PrecisionMetric.at(0.4, 0.5, 0.6),
|
||||
RecallMetric.at(0.5, 0.7),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
|
||||
ConfusionMatrixMetric.at(0.5),
|
||||
AucRocMetric.withCurve()));
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
|
||||
|
||||
// tag::evaluate-data-frame-execute-listener
|
||||
ActionListener<EvaluateDataFrameResponse> listener = new ActionListener<EvaluateDataFrameResponse>() {
|
||||
|
@ -3452,21 +3458,39 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.startObject("predicted_class")
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.startObject("ml.top_classes")
|
||||
.field("type", "nested")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject());
|
||||
TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
|
||||
return new IndexRequest()
|
||||
.source(XContentType.JSON,
|
||||
"actual_class", actualClass,
|
||||
"predicted_class", predictedClass,
|
||||
"ml.top_classes", Arrays.asList(
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", predictedClass);
|
||||
put("class_probability", p);
|
||||
}},
|
||||
new HashMap<String, Object>() {{
|
||||
put("class_name", "other");
|
||||
put("class_probability", 1 - p);
|
||||
}}));
|
||||
};
|
||||
BulkRequest bulkRequest =
|
||||
new BulkRequest(indexName)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8
|
||||
.add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9
|
||||
.add(indexRequest.apply("cat", "cat", 0.9)) // #0
|
||||
.add(indexRequest.apply("cat", "cat", 0.9)) // #1
|
||||
.add(indexRequest.apply("cat", "cat", 0.9)) // #2
|
||||
.add(indexRequest.apply("cat", "dog", 0.9)) // #3
|
||||
.add(indexRequest.apply("cat", "fox", 0.9)) // #4
|
||||
.add(indexRequest.apply("dog", "cat", 0.9)) // #5
|
||||
.add(indexRequest.apply("dog", "dog", 0.9)) // #6
|
||||
.add(indexRequest.apply("dog", "dog", 0.9)) // #7
|
||||
.add(indexRequest.apply("dog", "dog", 0.9)) // #8
|
||||
.add(indexRequest.apply("ant", "cat", 0.9)); // #9
|
||||
RestHighLevelClient client = highLevelClient();
|
||||
client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
client.bulk(bulkRequest, RequestOptions.DEFAULT);
|
||||
|
@ -3476,11 +3500,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
|
||||
"actual_class", // <2>
|
||||
"predicted_class", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
new AccuracyMetric(), // <5>
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6>
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7>
|
||||
new MulticlassConfusionMatrixMetric(3)); // <8>
|
||||
"ml.top_classes", // <4>
|
||||
// Evaluation metrics // <5>
|
||||
new AccuracyMetric(), // <6>
|
||||
new PrecisionMetric(), // <7>
|
||||
new RecallMetric(), // <8>
|
||||
new MulticlassConfusionMatrixMetric(3), // <9>
|
||||
AucRocMetric.forClass("cat")); // <10>
|
||||
// end::evaluate-data-frame-evaluation-classification
|
||||
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
||||
|
@ -3490,12 +3516,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
|
||||
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
|
||||
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
|
||||
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3>
|
||||
PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <3>
|
||||
double precision = precisionResult.getAvgPrecision(); // <4>
|
||||
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
|
||||
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5>
|
||||
RecallMetric.Result recallResult = response.getMetricByName(RecallMetric.NAME); // <5>
|
||||
double recall = recallResult.getAvgRecall(); // <6>
|
||||
|
||||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||
|
@ -3503,19 +3527,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
|
||||
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
|
||||
|
||||
AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
|
||||
double aucRocScore = aucRocResult.getScore(); // <11>
|
||||
Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
|
||||
// end::evaluate-data-frame-results-classification
|
||||
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||
assertThat(accuracy, equalTo(0.6));
|
||||
|
||||
assertThat(
|
||||
precisionResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
|
||||
assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
||||
assertThat(precision, equalTo(0.675));
|
||||
|
||||
assertThat(
|
||||
recallResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
|
||||
assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
||||
assertThat(recall, equalTo(0.45));
|
||||
|
||||
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
|
@ -3539,6 +3563,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
|
||||
0L))));
|
||||
assertThat(otherClassesCount, equalTo(0L));
|
||||
|
||||
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
||||
assertThat(aucRocScore, equalTo(0.2625));
|
||||
assertThat(aucRocDocCount, equalTo(5L));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
|
|||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetricResultTests;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetricResultTests;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
|
@ -16,7 +16,7 @@
|
|||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
@ -31,6 +31,7 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
|
|||
public static AucRocMetric.Result randomResult() {
|
||||
return new AucRocMetric.Result(
|
||||
randomDouble(),
|
||||
randomLong(),
|
||||
Stream
|
||||
.generate(AucRocMetricAucRocPointTests::randomPoint)
|
||||
.limit(randomIntBetween(1, 10))
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static AucRocMetric createRandom() {
|
||||
return new AucRocMetric(
|
||||
randomAlphaOfLengthBetween(1, 10),
|
||||
randomBoolean() ? randomBoolean() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AucRocMetric createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return AucRocMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -40,11 +40,16 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||
List<EvaluationMetric> metrics =
|
||||
randomSubsetOf(
|
||||
Arrays.asList(
|
||||
AucRocMetricTests.createRandom(),
|
||||
AccuracyMetricTests.createRandom(),
|
||||
PrecisionMetricTests.createRandom(),
|
||||
RecallMetricTests.createRandom(),
|
||||
MulticlassConfusionMatrixMetricTests.createRandom()));
|
||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
return new Classification(
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ? randomAlphaOfLength(10) : null,
|
||||
randomBoolean() ? randomAlphaOfLength(10) : null,
|
||||
metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static AucRocMetric createRandom() {
|
||||
return new AucRocMetric(randomBoolean() ? randomBoolean() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AucRocMetric createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return AucRocMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -40,7 +40,7 @@ public class OutlierDetectionTests extends AbstractXContentTestCase<OutlierDetec
|
|||
public static OutlierDetection createRandom() {
|
||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new AucRocMetric(randomBoolean()));
|
||||
metrics.add(AucRocMetricTests.createRandom());
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1,
|
||||
|
|
|
@ -51,11 +51,13 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
|
|||
<1> Constructing a new evaluation
|
||||
<2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
|
||||
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
|
||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<5> Accuracy
|
||||
<6> Precision
|
||||
<7> Recall
|
||||
<8> Multiclass confusion matrix of size 3
|
||||
<4> Name of the field in the index. Its value denotes the array of top classes. Must be nested.
|
||||
<5> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<6> Accuracy
|
||||
<7> Precision
|
||||
<8> Recall
|
||||
<9> Multiclass confusion matrix of size 3
|
||||
<10> {wikipedia}/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated for class "cat" treated as positive and the rest as negative
|
||||
|
||||
===== Regression
|
||||
|
||||
|
@ -115,6 +117,9 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
|
|||
<7> Fetching multiclass confusion matrix metric by name
|
||||
<8> Fetching the contents of the confusion matrix
|
||||
<9> Fetching the number of classes that were not included in the matrix
|
||||
<10> Fetching AucRoc metric by name
|
||||
<11> Fetching the actual AucRoc score
|
||||
<12> Fetching the number of documents that were used in order to calculate AucRoc score
|
||||
|
||||
===== Regression
|
||||
|
||||
|
|
Loading…
Reference in New Issue