[7.x] Implement precision and recall metrics for classification evaluation (#49671) (#50378)

This commit is contained in:
Przemysław Witek 2019-12-19 18:55:05 +01:00 committed by GitHub
parent 903305284d
commit cc4bc797f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 2493 additions and 388 deletions

View File

@ -35,6 +35,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
public class EvaluateDataFrameResponse implements ToXContentObject {
@ -47,7 +48,7 @@ public class EvaluateDataFrameResponse implements ToXContentObject {
ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
String evaluationName = parser.currentName();
parser.nextToken();
Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric);
Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, p -> parseMetric(evaluationName, p));
List<EvaluationMetric.Result> knownMetrics =
metrics.values().stream()
.filter(Objects::nonNull) // Filter out null values returned by {@link EvaluateDataFrameResponse::parseMetric}.
@ -56,10 +57,10 @@ public class EvaluateDataFrameResponse implements ToXContentObject {
return new EvaluateDataFrameResponse(evaluationName, knownMetrics);
}
private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException {
private static EvaluationMetric.Result parseMetric(String evaluationName, XContentParser parser) throws IOException {
String metricName = parser.currentName();
try {
return parser.namedObject(EvaluationMetric.Result.class, metricName, null);
return parser.namedObject(EvaluationMetric.Result.class, registeredMetricName(evaluationName, metricName), null);
} catch (NamedObjectNotFoundException e) {
parser.skipChildren();
// Metric name not recognized. Return {@code null} value here and filter it out later.

View File

@ -20,24 +20,36 @@ package org.elasticsearch.client.ml.dataframe.evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import java.util.Arrays;
import java.util.List;
public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
/**
* Constructs the name under which a metric (or metric result) is registered.
* The name is prefixed with evaluation name so that registered names are unique.
*
* @param evaluationName name of the evaluation
* @param metricName name of the metric
* @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
*/
public static String registeredMetricName(String evaluationName, String metricName) {
return evaluationName + "." + metricName;
}
@Override
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
return Arrays.asList(
@ -47,39 +59,91 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
// Evaluation metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(MulticlassConfusionMatrixMetric.NAME),
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
AucRocMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
PrecisionMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
RecallMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
ConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
AccuracyMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
MulticlassConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
MeanSquaredErrorMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent),
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
RSquaredMetric::fromXContent),
// Evaluation metrics results
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
AucRocMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(MulticlassConfusionMatrixMetric.NAME),
new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
PrecisionMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
ConfusionMatrixMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
AccuracyMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
MulticlassConfusionMatrixMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
MeanSquaredErrorMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
RSquaredMetric.Result::fromXContent)
);
}
}

View File

@ -32,6 +32,10 @@ import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
* Evaluation of classification results.
*/
@ -48,10 +52,10 @@ public class Classification implements Evaluation {
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
PARSER.declareString(constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
}
public static Classification fromXContent(XContentParser parser) {

View File

@ -0,0 +1,201 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* {@link PrecisionMetric} is a metric that answers the question:
* "What fraction of documents classified as X actually belongs to X?"
* for any given class X
*
* equation: precision(X) = TP(X) / (TP(X) + FP(X))
* where: TP(X) - number of true positives wrt X
* FP(X) - number of false positives wrt X
*/
public class PrecisionMetric implements EvaluationMetric {
public static final String NAME = "precision";
private static final ObjectParser<PrecisionMetric, Void> PARSER = new ObjectParser<>(NAME, true, PrecisionMetric::new);
public static PrecisionMetric fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public PrecisionMetric() {}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
}
@Override
public int hashCode() {
return Objects.hashCode(NAME);
}
public static class Result implements EvaluationMetric.Result {
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField AVG_PRECISION = new ParseField("avg_precision");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("precision_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), AVG_PRECISION);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Average of per-class precisions. */
private final double avgPrecision;
public Result(List<PerClassResult> classes, double avgPrecision) {
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.avgPrecision = avgPrecision;
}
@Override
public String getMetricName() {
return NAME;
}
public List<PerClassResult> getClasses() {
return classes;
}
public double getAvgPrecision() {
return avgPrecision;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASSES.getPreferredName(), classes);
builder.field(AVG_PRECISION.getPreferredName(), avgPrecision);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.classes, that.classes)
&& this.avgPrecision == that.avgPrecision;
}
@Override
public int hashCode() {
return Objects.hash(classes, avgPrecision);
}
}
public static class PerClassResult implements ToXContentObject {
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField PRECISION = new ParseField("precision");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static {
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), PRECISION);
}
/** Name of the class. */
private final String className;
/** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */
private final double precision;
public PerClassResult(String className, double precision) {
this.className = Objects.requireNonNull(className);
this.precision = precision;
}
public String getClassName() {
return className;
}
public double getPrecision() {
return precision;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(PRECISION.getPreferredName(), precision);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.precision == that.precision;
}
@Override
public int hashCode() {
return Objects.hash(className, precision);
}
}
}

View File

@ -0,0 +1,201 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
* {@link RecallMetric} is a metric that answers the question:
* "What fraction of documents belonging to X have been predicted as X by the classifier?"
* for any given class X
*
* equation: recall(X) = TP(X) / (TP(X) + FN(X))
* where: TP(X) - number of true positives wrt X
* FN(X) - number of false negatives wrt X
*/
public class RecallMetric implements EvaluationMetric {
public static final String NAME = "recall";
private static final ObjectParser<RecallMetric, Void> PARSER = new ObjectParser<>(NAME, true, RecallMetric::new);
public static RecallMetric fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public RecallMetric() {}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
}
@Override
public int hashCode() {
return Objects.hashCode(NAME);
}
public static class Result implements EvaluationMetric.Result {
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField AVG_RECALL = new ParseField("avg_recall");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("recall_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), AVG_RECALL);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Average of per-class recalls. */
private final double avgRecall;
public Result(List<PerClassResult> classes, double avgRecall) {
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
this.avgRecall = avgRecall;
}
@Override
public String getMetricName() {
return NAME;
}
public List<PerClassResult> getClasses() {
return classes;
}
public double getAvgRecall() {
return avgRecall;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASSES.getPreferredName(), classes);
builder.field(AVG_RECALL.getPreferredName(), avgRecall);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.classes, that.classes)
&& this.avgRecall == that.avgRecall;
}
@Override
public int hashCode() {
return Objects.hash(classes, avgRecall);
}
}
public static class PerClassResult implements ToXContentObject {
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField RECALL = new ParseField("recall");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static {
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), RECALL);
}
/** Name of the class. */
private final String className;
/** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */
private final double recall;
public PerClassResult(String className, double recall) {
this.className = Objects.requireNonNull(className);
this.recall = recall;
}
public String getClassName() {
return className;
}
public double getRecall() {
return recall;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(RECALL.getPreferredName(), recall);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.recall == that.recall;
}
@Override
public int hashCode() {
return Objects.hash(className, recall);
}
}
}

View File

@ -33,6 +33,10 @@ import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
* Evaluation of regression results.
*/
@ -49,10 +53,10 @@ public class Regression implements Evaluation {
NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
PARSER.declareString(constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
}
public static Regression fromXContent(XContentParser parser) {

View File

@ -33,6 +33,7 @@ import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -59,7 +60,8 @@ public class BinarySoftClassification implements Evaluation {
static {
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
PARSER.declareString(constructorArg(), PREDICTED_PROBABILITY_FIELD);
PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, n, null), METRICS);
PARSER.declareNamedObjects(
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), null), METRICS);
}
public static BinarySoftClassification fromXContent(XContentParser parser) {

View File

@ -1860,6 +1860,70 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
}
{ // Precision
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
indexName,
null,
new Classification(
actualClassField,
predictedClassField,
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
evaluateDataFrameResponse.getMetricByName(
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
assertThat(
precisionResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
assertThat(
precisionResult.getClasses(),
equalTo(
Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75))));
assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
}
{ // Recall
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(
indexName,
null,
new Classification(
actualClassField,
predictedClassField,
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
evaluateDataFrameResponse.getMetricByName(
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
assertThat(
recallResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
assertThat(
recallResult.getClasses(),
equalTo(
Arrays.asList(
// 3 out of 5 examples labeled as "cat" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6),
// 3 out of 4 examples labeled as "dog" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75),
// no examples labeled as "ant" were classified correctly
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0))));
assertThat(recallResult.getAvgRecall(), equalTo(0.45));
}
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
EvaluateDataFrameRequest evaluateDataFrameRequest =
new EvaluateDataFrameRequest(

View File

@ -128,6 +128,7 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
import static org.hamcrest.CoreMatchers.endsWith;
import static org.hamcrest.CoreMatchers.equalTo;
@ -688,7 +689,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(51, namedXContents.size());
assertEquals(55, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -730,26 +731,36 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertThat(names,
hasItems(AucRocMetric.NAME,
PrecisionMetric.NAME,
RecallMetric.NAME,
ConfusionMatrixMetric.NAME,
AccuracyMetric.NAME,
MulticlassConfusionMatrixMetric.NAME,
MeanSquaredErrorMetric.NAME,
RSquaredMetric.NAME));
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
hasItems(
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME),
registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME),
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME),
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names,
hasItems(AucRocMetric.NAME,
PrecisionMetric.NAME,
RecallMetric.NAME,
ConfusionMatrixMetric.NAME,
AccuracyMetric.NAME,
MulticlassConfusionMatrixMetric.NAME,
MeanSquaredErrorMetric.NAME,
RSquaredMetric.NAME));
hasItems(
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME),
registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME),
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME),
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
registeredMetricName(
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));

View File

@ -3372,7 +3372,9 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
"predicted_class", // <3>
// Evaluation metrics // <4>
new AccuracyMetric(), // <5>
new MulticlassConfusionMatrixMetric(3)); // <6>
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6>
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7>
new MulticlassConfusionMatrixMetric(3)); // <8>
// end::evaluate-data-frame-evaluation-classification
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@ -3382,16 +3384,34 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3>
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3>
double precision = precisionResult.getAvgPrecision(); // <4>
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4>
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5>
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5>
double recall = recallResult.getAvgRecall(); // <6>
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <7>
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
// end::evaluate-data-frame-results-classification
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
assertThat(accuracy, equalTo(0.6));
assertThat(
precisionResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
assertThat(precision, equalTo(0.675));
assertThat(
recallResult.getMetricName(),
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
assertThat(recall, equalTo(0.45));
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
assertThat(
confusionMatrix,
@ -3412,7 +3432,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
4L,
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
0L))));
assertThat(otherActualClassCount, equalTo(0L));
assertThat(otherClassesCount, equalTo(0L));
}
}

View File

@ -64,6 +64,8 @@ public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<Eva
metrics = randomSubsetOf(
Arrays.asList(
AccuracyMetricResultTests.randomResult(),
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetricResultTests.randomResult(),
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetricResultTests.randomResult(),
MulticlassConfusionMatrixMetricResultTests.randomResult()));
break;
default:

View File

@ -41,6 +41,8 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
randomSubsetOf(
Arrays.asList(
AccuracyMetricTests.createRandom(),
PrecisionMetricTests.createRandom(),
RecallMetricTests.createRandom(),
MulticlassConfusionMatrixMetricTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}

View File

@ -0,0 +1,67 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class PrecisionMetricResultTests extends AbstractXContentTestCase<Result> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
public static Result randomResult() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double precision = randomDoubleBetween(0.0, 1.0, true);
classes.add(new PerClassResult(classNames.get(i), precision));
}
double avgPrecision = randomDoubleBetween(0.0, 1.0, true);
return new Result(classes, avgPrecision);
}
@Override
protected Result createTestInstance() {
return randomResult();
}
@Override
protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class PrecisionMetricTests extends AbstractXContentTestCase<PrecisionMetric> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
static PrecisionMetric createRandom() {
return new PrecisionMetric();
}
@Override
protected PrecisionMetric createTestInstance() {
return createRandom();
}
@Override
protected PrecisionMetric doParseInstance(XContentParser parser) throws IOException {
return PrecisionMetric.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -0,0 +1,67 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class RecallMetricResultTests extends AbstractXContentTestCase<Result> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
public static Result randomResult() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double recall = randomDoubleBetween(0.0, 1.0, true);
classes.add(new PerClassResult(classNames.get(i), recall));
}
double avgRecall = randomDoubleBetween(0.0, 1.0, true);
return new Result(classes, avgRecall);
}
@Override
protected Result createTestInstance() {
return randomResult();
}
@Override
protected Result doParseInstance(XContentParser parser) throws IOException {
return Result.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class RecallMetricTests extends AbstractXContentTestCase<RecallMetric> {
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
static RecallMetric createRandom() {
return new RecallMetric();
}
@Override
protected RecallMetric createTestInstance() {
return createRandom();
}
@Override
protected RecallMetric doParseInstance(XContentParser parser) throws IOException {
return RecallMetric.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -53,7 +53,9 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
<5> Accuracy
<6> Multiclass confusion matrix of size 3
<6> Precision
<7> Recall
<8> Multiclass confusion matrix of size 3
===== Regression
@ -104,9 +106,13 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
<1> Fetching accuracy metric by name
<2> Fetching the actual accuracy value
<3> Fetching multiclass confusion matrix metric by name
<4> Fetching the contents of the confusion matrix
<5> Fetching the number of classes that were not included in the matrix
<3> Fetching precision metric by name
<4> Fetching the actual precision value
<5> Fetching recall metric by name
<6> Fetching the actual recall value
<7> Fetching multiclass confusion matrix metric by name
<8> Fetching the contents of the confusion matrix
<9> Fetching the number of classes that were not included in the matrix
===== Regression
@ -118,4 +124,4 @@ include-tagged::{doc-tests-file}[{api}-results-regression]
<1> Fetching mean squared error metric by name
<2> Fetching the actual mean squared error value
<3> Fetching R squared metric by name
<4> Fetching the actual R squared value
<4> Fetching the actual R squared value

View File

@ -79,7 +79,6 @@ import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction;
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
@ -91,6 +90,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction;
import org.elasticsearch.xpack.core.ml.action.FlushJobAction;
@ -146,18 +146,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
@ -267,6 +256,9 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Stream;
import static java.util.stream.Collectors.toList;
public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPlugin {
@ -474,7 +466,8 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
@Override
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return Arrays.asList(
return Stream.concat(
Arrays.asList(
// graph
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.GRAPH, GraphFeatureSetUsage::new),
// logstash
@ -502,28 +495,6 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
// ML - Data frame evaluation
new NamedWriteableRegistry.Entry(
Evaluation.class,
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification.NAME.getPreferredName(),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification::new),
new NamedWriteableRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix.Result::new),
new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, Accuracy.NAME.getPreferredName(), Accuracy.Result::new),
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new),
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new),
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new),
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new),
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), AucRoc.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix.Result::new),
// ML - Inference preprocessing
new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new),
new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
@ -628,7 +599,9 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.ANALYTICS, AnalyticsFeatureSetUsage::new),
// Enrich
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.ENRICH, EnrichFeatureSet.Usage::new)
);
).stream(),
MlEvaluationNamedXContentProvider.getNamedWriteables().stream()
).collect(toList());
}
@Override

View File

@ -7,12 +7,14 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -76,8 +78,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
for (EvaluationMetric metric : getMetrics()) {
// Fetch aggregations requested by individual metrics
List<AggregationBuilder> aggs = metric.aggs(getActualField(), getPredictedField());
aggs.forEach(searchSourceBuilder::aggregation);
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField());
aggs.v1().forEach(searchSourceBuilder::aggregation);
aggs.v2().forEach(searchSourceBuilder::aggregation);
}
return searchSourceBuilder;
}

View File

@ -6,10 +6,12 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import java.util.List;
import java.util.Optional;
@ -30,7 +32,7 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
* @param predictedField the field that stores the predicted value (class name or probability)
* @return the aggregations required to compute the metric
*/
List<AggregationBuilder> aggs(String actualField, String predictedField);
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField);
/**
* Processes given aggregations as a step towards computing result

View File

@ -5,109 +5,179 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
/**
* Constructs the name under which a metric (or metric result) is registered.
* The name is prefixed with evaluation name so that registered names are unique.
*
* @param evaluationName name of the evaluation
* @param metricName name of the metric
* @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
*/
public static String registeredMetricName(ParseField evaluationName, ParseField metricName) {
return registeredMetricName(evaluationName.getPreferredName(), metricName.getPreferredName());
}
/**
* Constructs the name under which a metric (or metric result) is registered.
* The name is prefixed with evaluation name so that registered names are unique.
*
* @param evaluationName name of the evaluation
* @param metricName name of the metric
* @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
*/
public static String registeredMetricName(String evaluationName, String metricName) {
return evaluationName + "." + metricName;
}
@Override
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
return Arrays.asList(
// Evaluations
new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent),
new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent),
new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent),
// Evaluations
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
BinarySoftClassification::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
// Soft classification metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME)),
AucRoc::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, Precision.NAME)),
Precision::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, Recall.NAME)),
Recall::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME)),
ConfusionMatrix::fromXContent),
// Soft classification metrics
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
ConfusionMatrix::fromXContent));
// Classification metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)),
MulticlassConfusionMatrix::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)),
Accuracy::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent),
// Classification metrics
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
MulticlassConfusionMatrix::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent));
// Regression metrics
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
return namedXContent;
// Regression metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)),
MeanSquaredError::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
RSquared::fromXContent)
);
}
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return Arrays.asList(
// Evaluations
new NamedWriteableRegistry.Entry(Evaluation.class,
BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new),
new NamedWriteableRegistry.Entry(Evaluation.class,
Classification.NAME.getPreferredName(),
Classification::new),
new NamedWriteableRegistry.Entry(Evaluation.class,
Regression.NAME.getPreferredName(),
Regression::new),
// Evaluations
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(),
Classification::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
// Evaluation metrics
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
AucRoc::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(BinarySoftClassification.NAME, Precision.NAME),
Precision::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(BinarySoftClassification.NAME, Recall.NAME),
Recall::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
ConfusionMatrix::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
MulticlassConfusionMatrix::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Classification.NAME, Accuracy.NAME),
Accuracy::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
MeanSquaredError::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared::new),
// Evaluation Metrics
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(),
AucRoc::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(),
Precision::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(),
Recall::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
RSquared.NAME.getPreferredName(),
RSquared::new));
// Evaluation Metrics Results
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
AucRoc.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME,
ScoreByThresholdResult::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
MulticlassConfusionMatrix.NAME.getPreferredName(),
MulticlassConfusionMatrix.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
Accuracy.NAME.getPreferredName(),
Accuracy.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
RSquared.NAME.getPreferredName(),
RSquared.Result::new));
return namedWriteables;
// Evaluation metrics results
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
AucRoc.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(BinarySoftClassification.NAME, ScoreByThresholdResult.NAME),
ScoreByThresholdResult::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
ConfusionMatrix.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
MulticlassConfusionMatrix.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Classification.NAME, Accuracy.NAME),
Accuracy.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
MeanSquaredError.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared.Result::new)
);
}
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@ -18,8 +19,10 @@ import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -34,6 +37,7 @@ import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* {@link Accuracy} is a metric that answers the question:
@ -41,7 +45,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
*
* equation: accuracy = 1/n * Σ(y == y´)
*/
public class Accuracy implements ClassificationMetric {
public class Accuracy implements EvaluationMetric {
public static final ParseField NAME = new ParseField("accuracy");
@ -68,7 +72,7 @@ public class Accuracy implements ClassificationMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Classification.NAME, NAME);
}
@Override
@ -77,16 +81,18 @@ public class Accuracy implements ClassificationMetric {
}
@Override
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (result != null) {
return Collections.emptyList();
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
Script accuracyScript = new Script(buildScript(actualField, predictedField));
return Arrays.asList(
AggregationBuilders.terms(CLASSES_AGG_NAME)
.field(actualField)
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript));
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(CLASSES_AGG_NAME)
.field(actualField)
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
Collections.emptyList());
}
@Override
@ -169,7 +175,7 @@ public class Accuracy implements ClassificationMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Classification.NAME, NAME);
}
@Override

View File

@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -20,6 +21,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Evaluation of classification results.
*/
@ -33,13 +36,13 @@ public class Classification implements Evaluation {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<ClassificationMetric>) a[2]));
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS);
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
}
public static Classification fromXContent(XContentParser parser) {
@ -61,22 +64,22 @@ public class Classification implements Evaluation {
/**
* The list of metrics to calculate
*/
private final List<ClassificationMetric> metrics;
private final List<EvaluationMetric> metrics;
public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.metrics = initMetrics(metrics, Classification::defaultMetrics);
}
private static List<ClassificationMetric> defaultMetrics() {
private static List<EvaluationMetric> defaultMetrics() {
return Arrays.asList(new MulticlassConfusionMatrix());
}
public Classification(StreamInput in) throws IOException {
this.actualField = in.readString();
this.predictedField = in.readString();
this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
}
@Override
@ -95,7 +98,7 @@ public class Classification implements Evaluation {
}
@Override
public List<ClassificationMetric> getMetrics() {
public List<EvaluationMetric> getMetrics() {
return metrics;
}
@ -118,8 +121,8 @@ public class Classification implements Evaluation {
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
builder.startObject(METRICS.getPreferredName());
for (ClassificationMetric metric : metrics) {
builder.field(metric.getWriteableName(), metric);
for (EvaluationMetric metric : metrics) {
builder.field(metric.getName(), metric);
}
builder.endObject();

View File

@ -1,11 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
public interface ClassificationMetric extends EvaluationMetric {
}

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@ -19,10 +20,12 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -38,13 +41,14 @@ import java.util.stream.Collectors;
import static java.util.Comparator.comparing;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* {@link MulticlassConfusionMatrix} is a metric that answers the question:
* "How many examples belonging to class X were classified as Y by the classifier?"
* "How many documents belonging to class X were classified as Y by the classifier?"
* for all the possible class pairs {X, Y}.
*/
public class MulticlassConfusionMatrix implements ClassificationMetric {
public class MulticlassConfusionMatrix implements EvaluationMetric {
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
@ -92,7 +96,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Classification.NAME, NAME);
}
@Override
@ -105,13 +109,15 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
}
@Override
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (topActualClassNames == null) { // This is step 1
return Arrays.asList(
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size));
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(size)),
Collections.emptyList());
}
if (result == null) { // This is step 2
KeyedFilter[] keyedFiltersActual =
@ -122,15 +128,17 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
topActualClassNames.stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
return Arrays.asList(
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
.field(actualField),
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY)));
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
.field(actualField),
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
.otherBucket(true)
.otherBucketKey(OTHER_BUCKET_KEY))),
Collections.emptyList());
}
return Collections.emptyList();
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
@Override
@ -232,7 +240,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Classification.NAME, NAME);
}
@Override
@ -301,7 +309,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
/** Name of the actual class. */
private final String actualClass;
/** Number of documents (examples) belonging to the {code actualClass} class. */
/** Number of documents belonging to the {code actualClass} class. */
private final long actualClassDocCount;
/** List of predicted classes. */
private final List<PredictedClass> predictedClasses;

View File

@ -0,0 +1,347 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* {@link Precision} is a metric that answers the question:
* "What fraction of documents classified as X actually belongs to X?"
* for any given class X
*
* equation: precision(X) = TP(X) / (TP(X) + FP(X))
* where: TP(X) - number of true positives wrt X
* FP(X) - number of false positives wrt X
*/
public class Precision implements EvaluationMetric {
public static final ParseField NAME = new ParseField("precision");
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String AGG_NAME_PREFIX = "classification_precision_";
static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";
private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
}
private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);
public static Precision fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
private final int maxClassesCardinality;
private String actualField;
private List<String> topActualClassNames;
private EvaluationMetricResult result;
public Precision() {
this((Integer) null);
}
// Visible for testing
public Precision(@Nullable Integer maxClassesCardinality) {
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Precision(StreamInput in) throws IOException {
this.maxClassesCardinality = in.readVInt();
}
@Override
public String getWriteableName() {
return registeredMetricName(Classification.NAME, NAME);
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField;
if (topActualClassNames == null) { // This is step 1
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
.field(actualField)
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
.size(maxClassesCardinality)),
Collections.emptyList());
}
if (result == null) { // This is step 2
KeyedFilter[] keyedFiltersPredicted =
topActualClassNames.stream()
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
.toArray(KeyedFilter[]::new);
Script script = buildScript(actualField, predictedField);
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)
.subAggregation(AggregationBuilders.avg(PER_PREDICTED_CLASS_PRECISION_AGG_NAME).script(script))),
Arrays.asList(
PipelineAggregatorBuilders.avgBucket(
AVG_PRECISION_AGG_NAME,
BY_PREDICTED_CLASS_AGG_NAME + ">" + PER_PREDICTED_CLASS_PRECISION_AGG_NAME)));
}
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
@Override
public void process(Aggregations aggs) {
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
// This means there were more than {@code maxClassesCardinality} buckets.
// We cannot calculate average precision accurately, so we fail.
throw ExceptionsHelper.badRequestException(
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
}
topActualClassNames =
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
}
if (result == null &&
aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
NumericMetricsAggregation.SingleValue avgPrecisionAgg = aggs.get(AVG_PRECISION_AGG_NAME);
List<PerClassResult> classes = new ArrayList<>(byPredictedClassAgg.getBuckets().size());
for (Filters.Bucket bucket : byPredictedClassAgg.getBuckets()) {
String className = bucket.getKeyAsString();
NumericMetricsAggregation.SingleValue precisionAgg = bucket.getAggregations().get(PER_PREDICTED_CLASS_PRECISION_AGG_NAME);
double precision = precisionAgg.value();
if (Double.isFinite(precision)) {
classes.add(new PerClassResult(className, precision));
}
}
result = new Result(classes, avgPrecisionAgg.value());
}
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(maxClassesCardinality);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
}
@Override
public int hashCode() {
return Objects.hashCode(NAME.getPreferredName());
}
public static class Result implements EvaluationMetricResult {
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField AVG_PRECISION = new ParseField("avg_precision");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("precision_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), AVG_PRECISION);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Average of per-class precisions. */
private final double avgPrecision;
public Result(List<PerClassResult> classes, double avgPrecision) {
this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
this.avgPrecision = avgPrecision;
}
public Result(StreamInput in) throws IOException {
this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
this.avgPrecision = in.readDouble();
}
@Override
public String getWriteableName() {
return registeredMetricName(Classification.NAME, NAME);
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
public List<PerClassResult> getClasses() {
return classes;
}
public double getAvgPrecision() {
return avgPrecision;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(classes);
out.writeDouble(avgPrecision);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASSES.getPreferredName(), classes);
builder.field(AVG_PRECISION.getPreferredName(), avgPrecision);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.classes, that.classes)
&& this.avgPrecision == that.avgPrecision;
}
@Override
public int hashCode() {
return Objects.hash(classes, avgPrecision);
}
}
public static class PerClassResult implements ToXContentObject, Writeable {
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField PRECISION = new ParseField("precision");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static {
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), PRECISION);
}
/** Name of the class. */
private final String className;
/** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */
private final double precision;
public PerClassResult(String className, double precision) {
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
this.precision = precision;
}
public PerClassResult(StreamInput in) throws IOException {
this.className = in.readString();
this.precision = in.readDouble();
}
public String getClassName() {
return className;
}
public double getPrecision() {
return precision;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(className);
out.writeDouble(precision);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(PRECISION.getPreferredName(), precision);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.precision == that.precision;
}
@Override
public int hashCode() {
return Objects.hash(className, precision);
}
}
}

View File

@ -0,0 +1,321 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* {@link Recall} is a metric that answers the question:
* "What fraction of documents belonging to X have been predicted as X by the classifier?"
* for any given class X
*
* equation: recall(X) = TP(X) / (TP(X) + FN(X))
* where: TP(X) - number of true positives wrt X
* FN(X) - number of false negatives wrt X
*/
public class Recall implements EvaluationMetric {
public static final ParseField NAME = new ParseField("recall");
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
private static final String AGG_NAME_PREFIX = "classification_recall_";
static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";
private static Script buildScript(Object...args) {
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
}
private static final ObjectParser<Recall, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new);
public static Recall fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
private final int maxClassesCardinality;
private String actualField;
private EvaluationMetricResult result;
public Recall() {
this((Integer) null);
}
// Visible for testing
public Recall(@Nullable Integer maxClassesCardinality) {
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
}
public Recall(StreamInput in) throws IOException {
this.maxClassesCardinality = in.readVInt();
}
@Override
public String getWriteableName() {
return registeredMetricName(Classification.NAME, NAME);
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
this.actualField = actualField;
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
Script script = buildScript(actualField, predictedField);
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
.field(actualField)
.size(maxClassesCardinality)
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
Arrays.asList(
PipelineAggregatorBuilders.avgBucket(
AVG_RECALL_AGG_NAME,
BY_ACTUAL_CLASS_AGG_NAME + ">" + PER_ACTUAL_CLASS_RECALL_AGG_NAME)));
}
@Override
public void process(Aggregations aggs) {
if (result == null &&
aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
if (byActualClassAgg.getSumOfOtherDocCounts() > 0) {
// This means there were more than {@code maxClassesCardinality} buckets.
// We cannot calculate average recall accurately, so we fail.
throw ExceptionsHelper.badRequestException(
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
}
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
for (Terms.Bucket bucket : byActualClassAgg.getBuckets()) {
String className = bucket.getKeyAsString();
NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
classes.add(new PerClassResult(className, recallAgg.value()));
}
result = new Result(classes, avgRecallAgg.value());
}
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(maxClassesCardinality);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
}
@Override
public int hashCode() {
return Objects.hashCode(NAME.getPreferredName());
}
public static class Result implements EvaluationMetricResult {
private static final ParseField CLASSES = new ParseField("classes");
private static final ParseField AVG_RECALL = new ParseField("avg_recall");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("recall_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
static {
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
PARSER.declareDouble(constructorArg(), AVG_RECALL);
}
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
/** List of per-class results. */
private final List<PerClassResult> classes;
/** Average of per-class recalls. */
private final double avgRecall;
public Result(List<PerClassResult> classes, double avgRecall) {
this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
this.avgRecall = avgRecall;
}
public Result(StreamInput in) throws IOException {
this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
this.avgRecall = in.readDouble();
}
@Override
public String getWriteableName() {
return registeredMetricName(Classification.NAME, NAME);
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
public List<PerClassResult> getClasses() {
return classes;
}
public double getAvgRecall() {
return avgRecall;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(classes);
out.writeDouble(avgRecall);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASSES.getPreferredName(), classes);
builder.field(AVG_RECALL.getPreferredName(), avgRecall);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(this.classes, that.classes)
&& this.avgRecall == that.avgRecall;
}
@Override
public int hashCode() {
return Objects.hash(classes, avgRecall);
}
}
public static class PerClassResult implements ToXContentObject, Writeable {
private static final ParseField CLASS_NAME = new ParseField("class_name");
private static final ParseField RECALL = new ParseField("recall");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
static {
PARSER.declareString(constructorArg(), CLASS_NAME);
PARSER.declareDouble(constructorArg(), RECALL);
}
/** Name of the class. */
private final String className;
/** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */
private final double recall;
public PerClassResult(String className, double recall) {
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
this.recall = recall;
}
public PerClassResult(StreamInput in) throws IOException {
this.className = in.readString();
this.recall = in.readDouble();
}
public String getClassName() {
return className;
}
public double getRecall() {
return recall;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(className);
out.writeDouble(recall);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(RECALL.getPreferredName(), recall);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PerClassResult that = (PerClassResult) o;
return Objects.equals(this.className, that.className)
&& this.recall == that.recall;
}
@Override
public int hashCode() {
return Objects.hash(className, recall);
}
}
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
@ -15,7 +16,9 @@ import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.io.IOException;
@ -27,12 +30,14 @@ import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Calculates the mean squared error between two known numerical fields.
*
* equation: mse = 1/n * Σ(y - y´)^2
*/
public class MeanSquaredError implements RegressionMetric {
public class MeanSquaredError implements EvaluationMetric {
public static final ParseField NAME = new ParseField("mean_squared_error");
@ -62,11 +67,13 @@ public class MeanSquaredError implements RegressionMetric {
}
@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (result != null) {
return Collections.emptyList();
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
return Tuple.tuple(
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))),
Collections.emptyList());
}
@Override
@ -82,7 +89,7 @@ public class MeanSquaredError implements RegressionMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Regression.NAME, NAME);
}
@Override
@ -125,7 +132,7 @@ public class MeanSquaredError implements RegressionMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Regression.NAME, NAME);
}
@Override

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
@ -15,9 +16,11 @@ import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.io.IOException;
@ -29,6 +32,8 @@ import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Calculates R-Squared between two known numerical fields.
*
@ -37,7 +42,7 @@ import java.util.Optional;
* SSres = Σ(y - y´)^2, The residual sum of squares
* SStot = Σ(y - y_mean)^2, The total sum of squares
*/
public class RSquared implements RegressionMetric {
public class RSquared implements EvaluationMetric {
public static final ParseField NAME = new ParseField("r_squared");
@ -67,13 +72,15 @@ public class RSquared implements RegressionMetric {
}
@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
if (result != null) {
return Collections.emptyList();
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
return Arrays.asList(
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
return Tuple.tuple(
Arrays.asList(
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField)),
Collections.emptyList());
}
@Override
@ -97,7 +104,7 @@ public class RSquared implements RegressionMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Regression.NAME, NAME);
}
@Override
@ -140,7 +147,7 @@ public class RSquared implements RegressionMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(Regression.NAME, NAME);
}
@Override

View File

@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -20,6 +21,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Evaluation of regression results.
*/
@ -33,13 +36,13 @@ public class Regression implements Evaluation {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<RegressionMetric>) a[2]));
NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS);
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
}
public static Regression fromXContent(XContentParser parser) {
@ -61,22 +64,22 @@ public class Regression implements Evaluation {
/**
* The list of metrics to calculate
*/
private final List<RegressionMetric> metrics;
private final List<EvaluationMetric> metrics;
public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
this.metrics = initMetrics(metrics, Regression::defaultMetrics);
}
private static List<RegressionMetric> defaultMetrics() {
private static List<EvaluationMetric> defaultMetrics() {
return Arrays.asList(new MeanSquaredError(), new RSquared());
}
public Regression(StreamInput in) throws IOException {
this.actualField = in.readString();
this.predictedField = in.readString();
this.metrics = in.readNamedWriteableList(RegressionMetric.class);
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
}
@Override
@ -95,7 +98,7 @@ public class Regression implements Evaluation {
}
@Override
public List<RegressionMetric> getMetrics() {
public List<EvaluationMetric> getMetrics() {
return metrics;
}
@ -118,8 +121,8 @@ public class Regression implements Evaluation {
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
builder.startObject(METRICS.getPreferredName());
for (RegressionMetric metric : metrics) {
builder.field(metric.getWriteableName(), metric);
for (EvaluationMetric metric : metrics) {
builder.field(metric.getName(), metric);
}
builder.endObject();

View File

@ -1,11 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
public interface RegressionMetric extends EvaluationMetric {
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -15,6 +16,8 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -23,9 +26,9 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery;
abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
public static final ParseField AT = new ParseField("at");
@ -63,11 +66,11 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
}
@Override
public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
if (result != null) {
return Collections.emptyList();
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
return aggsAt(actualField, predictedProbabilityField);
return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), Collections.emptyList());
}
@Override

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@ -18,8 +19,10 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -33,7 +36,8 @@ import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery;
/**
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
@ -53,7 +57,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassific
* When this is used for multi-class classification, it will calculate the ROC
* curve of each class versus the rest.
*/
public class AucRoc implements SoftClassificationMetric {
public class AucRoc implements EvaluationMetric {
public static final ParseField NAME = new ParseField("auc_roc");
@ -88,7 +92,7 @@ public class AucRoc implements SoftClassificationMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(BinarySoftClassification.NAME, NAME);
}
@Override
@ -123,9 +127,9 @@ public class AucRoc implements SoftClassificationMetric {
}
@Override
public List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
if (result != null) {
return Collections.emptyList();
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
AggregationBuilder percentilesForClassValueAgg =
@ -138,7 +142,9 @@ public class AucRoc implements SoftClassificationMetric {
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
.subAggregation(
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
return Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg);
return Tuple.tuple(
Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg),
Collections.emptyList());
}
@Override
@ -330,7 +336,7 @@ public class AucRoc implements SoftClassificationMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(BinarySoftClassification.NAME, NAME);
}
@Override

View File

@ -12,7 +12,10 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
@ -20,6 +23,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Evaluation of binary soft classification methods, e.g. outlier detection.
* This is useful to evaluate problems where a model outputs a probability of whether
@ -34,19 +39,23 @@ public class BinarySoftClassification implements Evaluation {
private static final ParseField METRICS = new ParseField("metrics");
public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER = new ConstructingObjectParser<>(
NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List<SoftClassificationMetric>) a[2]));
NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD);
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
(p, c, n) -> p.namedObject(SoftClassificationMetric.class, n, null), METRICS);
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
}
public static BinarySoftClassification fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
static QueryBuilder actualIsTrueQuery(String actualField) {
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
}
/**
* The field where the actual class is marked up.
* The value of this field is assumed to either be 1 or 0, or true or false.
@ -61,16 +70,16 @@ public class BinarySoftClassification implements Evaluation {
/**
* The list of metrics to calculate
*/
private final List<SoftClassificationMetric> metrics;
private final List<EvaluationMetric> metrics;
public BinarySoftClassification(String actualField, String predictedProbabilityField,
@Nullable List<SoftClassificationMetric> metrics) {
@Nullable List<EvaluationMetric> metrics) {
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics);
}
private static List<SoftClassificationMetric> defaultMetrics() {
private static List<EvaluationMetric> defaultMetrics() {
return Arrays.asList(
new AucRoc(false),
new Precision(Arrays.asList(0.25, 0.5, 0.75)),
@ -81,7 +90,7 @@ public class BinarySoftClassification implements Evaluation {
public BinarySoftClassification(StreamInput in) throws IOException {
this.actualField = in.readString();
this.predictedProbabilityField = in.readString();
this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class);
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
}
@Override
@ -100,7 +109,7 @@ public class BinarySoftClassification implements Evaluation {
}
@Override
public List<SoftClassificationMetric> getMetrics() {
public List<EvaluationMetric> getMetrics() {
return metrics;
}
@ -123,7 +132,7 @@ public class BinarySoftClassification implements Evaluation {
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
builder.startObject(METRICS.getPreferredName());
for (SoftClassificationMetric metric : metrics) {
for (EvaluationMetric metric : metrics) {
builder.field(metric.getName(), metric);
}
builder.endObject();

View File

@ -21,6 +21,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
public static final ParseField NAME = new ParseField("confusion_matrix");
@ -46,7 +48,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(BinarySoftClassification.NAME, NAME);
}
@Override
@ -129,7 +131,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(BinarySoftClassification.NAME, NAME);
}
@Override

View File

@ -19,6 +19,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
public class Precision extends AbstractConfusionMatrixMetric {
public static final ParseField NAME = new ParseField("precision");
@ -44,7 +46,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(BinarySoftClassification.NAME, NAME);
}
@Override

View File

@ -19,6 +19,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
public class Recall extends AbstractConfusionMatrixMetric {
public static final ParseField NAME = new ParseField("recall");
@ -44,7 +46,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
@Override
public String getWriteableName() {
return NAME.getPreferredName();
return registeredMetricName(BinarySoftClassification.NAME, NAME);
}
@Override

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -13,9 +14,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
public class ScoreByThresholdResult implements EvaluationMetricResult {
public static final String NAME = "score_by_threshold_result";
public static final ParseField NAME = new ParseField("score_by_threshold_result");
private final String name;
private final double[] thresholds;
@ -36,7 +39,7 @@ public class ScoreByThresholdResult implements EvaluationMetricResult {
@Override
public String getWriteableName() {
return NAME;
return registeredMetricName(BinarySoftClassification.NAME, NAME);
}
@Override

View File

@ -1,17 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
public interface SoftClassificationMetric extends EvaluationMetric {
static QueryBuilder actualIsTrueQuery(String actualField) {
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
}
}

View File

@ -31,7 +31,7 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables());
namedWriteables.addAll(MlEvaluationNamedXContentProvider.getNamedWriteables());
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
return new NamedWriteableRegistry(namedWriteables);
}
@ -46,13 +46,11 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
@Override
protected Request createTestInstance() {
Request request = new Request();
int indicesCount = randomIntBetween(1, 5);
List<String> indices = new ArrayList<>(indicesCount);
for (int i = 0; i < indicesCount; i++) {
indices.add(randomAlphaOfLength(10));
}
request.setIndices(indices);
QueryProvider queryProvider = null;
if (randomBoolean()) {
try {
@ -62,10 +60,11 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
throw new UncheckedIOException(e);
}
}
request.setQueryProvider(queryProvider);
Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
request.setEvaluation(evaluation);
return request;
return new Request()
.setIndices(indices)
.setQueryProvider(queryProvider)
.setEvaluation(evaluation);
}
@Override

View File

@ -11,7 +11,10 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AccuracyResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
@ -22,7 +25,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
@ -30,11 +33,13 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
String evaluationName = randomAlphaOfLength(10);
List<EvaluationMetricResult> metrics =
Arrays.asList(
AccuracyResultTests.createRandom(),
PrecisionResultTests.createRandom(),
RecallResultTests.createRandom(),
MulticlassConfusionMatrixResultTests.createRandom(),
new MeanSquaredError.Result(randomDouble()),
new RSquared.Result(randomDouble()));
int numMetrics = randomIntBetween(0, metrics.size());
return new Response(evaluationName, metrics.subList(0, numMetrics));
return new Response(evaluationName, randomSubsetOf(metrics));
}
@Override

View File

@ -17,15 +17,9 @@ import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accuracy.Result> {
public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
}
@Override
protected Accuracy.Result createTestInstance() {
public static Result createRandom() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
@ -38,7 +32,17 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accurac
}
@Override
protected Writeable.Reader<Accuracy.Result> instanceReader() {
return Accuracy.Result::new;
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
protected Result createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Result> instanceReader() {
return Result::new;
}
}

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
@ -19,8 +20,10 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
@ -42,7 +45,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
@ -51,10 +54,12 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
}
public static Classification createRandom() {
List<ClassificationMetric> metrics =
List<EvaluationMetric> metrics =
randomSubsetOf(
Arrays.asList(
AccuracyTests.createRandom(),
PrecisionTests.createRandom(),
RecallTests.createRandom(),
MulticlassConfusionMatrixTests.createRandom()));
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
}
@ -101,10 +106,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
}
public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() {
ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
EvaluationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
EvaluationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
EvaluationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
EvaluationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
assertThat(metric1.getResult(), isEmpty());
@ -168,7 +173,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
* Number of steps is configurable.
* Upon reaching the last step, the result is produced.
*/
private static class FakeClassificationMetric implements ClassificationMetric {
private static class FakeClassificationMetric implements EvaluationMetric {
private final String name;
private final int numSteps;
@ -191,8 +196,8 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
}
@Override
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
return Collections.emptyList();
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
@Override

View File

@ -6,10 +6,12 @@
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
@ -25,9 +27,9 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregati
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
@ -74,8 +76,8 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
public void testAggs() {
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
assertThat(aggs, is(not(empty())));
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs("act", "pred");
assertThat(aggs, isTuple(not(empty()), empty()));
assertThat(confusionMatrix.getResult(), isEmpty());
}
@ -109,7 +111,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
assertThat(
@ -151,7 +153,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
confusionMatrix.process(aggs);
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
assertThat(

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class PrecisionResultTests extends AbstractWireSerializingTestCase<Result> {
public static Result createRandom() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double precision = randomDoubleBetween(0.0, 1.0, true);
classes.add(new PerClassResult(classNames.get(i), precision));
}
double avgPrecision = randomDoubleBetween(0.0, 1.0, true);
return new Result(classes, avgPrecision);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
protected Result createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Result> instanceReader() {
return Result::new;
}
}

View File

@ -0,0 +1,118 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
@Override
protected Precision doParseInstance(XContentParser parser) throws IOException {
return Precision.fromXContent(parser);
}
@Override
protected Precision createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Precision> instanceReader() {
return Precision::new;
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
public static Precision createRandom() {
return new Precision();
}
public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME),
mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Precision precision = new Precision();
precision.process(aggs);
assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty()));
assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
}
public void testProcess_GivenMissingAgg() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Precision precision = new Precision();
precision.process(aggs);
assertThat(precision.getResult(), isEmpty());
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Precision precision = new Precision();
precision.process(aggs);
assertThat(precision.getResult(), isEmpty());
}
}
public void testProcess_GivenAggOfWrongType() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
mockFilters(Precision.AVG_PRECISION_AGG_NAME)
));
Precision precision = new Precision();
precision.process(aggs);
assertThat(precision.getResult(), isEmpty());
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue(Precision.BY_PREDICTED_CLASS_AGG_NAME, 1.0),
mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123)
));
Precision precision = new Precision();
precision.process(aggs);
assertThat(precision.getResult(), isEmpty());
}
}
public void testProcess_GivenCardinalityTooHigh() {
Aggregations aggs =
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
Precision precision = new Precision();
precision.aggs("foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}
}

View File

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.PerClassResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class RecallResultTests extends AbstractWireSerializingTestCase<Result> {
public static Result createRandom() {
int numClasses = randomIntBetween(2, 100);
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
List<PerClassResult> classes = new ArrayList<>(numClasses);
for (int i = 0; i < numClasses; i++) {
double recall = randomDoubleBetween(0.0, 1.0, true);
classes.add(new PerClassResult(classNames.get(i), recall));
}
double avgRecall = randomDoubleBetween(0.0, 1.0, true);
return new Result(classes, avgRecall);
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
protected Result createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Result> instanceReader() {
return Result::new;
}
}

View File

@ -0,0 +1,117 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
public class RecallTests extends AbstractSerializingTestCase<Recall> {
@Override
protected Recall doParseInstance(XContentParser parser) throws IOException {
return Recall.fromXContent(parser);
}
@Override
protected Recall createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<Recall> instanceReader() {
return Recall::new;
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
public static Recall createRandom() {
return new Recall();
}
public void testProcess() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Recall recall = new Recall();
recall.process(aggs);
assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty()));
assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
}
public void testProcess_GivenMissingAgg() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Recall recall = new Recall();
recall.process(aggs);
assertThat(recall.getResult(), isEmpty());
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
Recall recall = new Recall();
recall.process(aggs);
assertThat(recall.getResult(), isEmpty());
}
}
public void testProcess_GivenAggOfWrongType() {
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
mockTerms(Recall.AVG_RECALL_AGG_NAME)
));
Recall recall = new Recall();
recall.process(aggs);
assertThat(recall.getResult(), isEmpty());
}
{
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue(Recall.BY_ACTUAL_CLASS_AGG_NAME, 1.0),
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)
));
Recall recall = new Recall();
recall.process(aggs);
assertThat(recall.getResult(), isEmpty());
}
}
public void testProcess_GivenCardinalityTooHigh() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
Recall recall = new Recall();
recall.aggs("foo", "bar");
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
}
}

View File

@ -0,0 +1,49 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
import org.elasticsearch.common.collect.Tuple;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
import java.util.Arrays;
public class TupleMatchers {
private static class TupleMatcher<V1, V2> extends TypeSafeMatcher<Tuple<? extends V1, ? extends V2>> {
private final Matcher<? super V1> v1Matcher;
private final Matcher<? super V2> v2Matcher;
private TupleMatcher(Matcher<? super V1> v1Matcher, Matcher<? super V2> v2Matcher) {
this.v1Matcher = v1Matcher;
this.v2Matcher = v2Matcher;
}
@Override
protected boolean matchesSafely(final Tuple<? extends V1, ? extends V2> item) {
return item != null && v1Matcher.matches(item.v1()) && v2Matcher.matches(item.v2());
}
@Override
public void describeTo(final Description description) {
description.appendText("expected tuple matching ").appendList("[", ", ", "]", Arrays.asList(v1Matcher, v2Matcher));
}
}
/**
* Creates a matcher that matches iff:
* 1. the examined tuple's <code>v1()</code> matches the specified <code>v1Matcher</code>
* and
* 2. the examined tuple's <code>v2()</code> matches the specified <code>v2Matcher</code>
* For example:
* <pre>assertThat(Tuple.tuple("myValue1", "myValue2"), isTuple(startsWith("my"), containsString("Val")))</pre>
*/
public static <V1, V2> TupleMatcher<? extends V1, ? extends V2> isTuple(Matcher<? super V1> v1Matcher, Matcher<? super V2> v2Matcher) {
return new TupleMatcher(v1Matcher, v2Matcher);
}
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException;
@ -29,7 +30,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
@ -38,7 +39,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
}
public static Regression createRandom() {
List<RegressionMetric> metrics = new ArrayList<>();
List<EvaluationMetric> metrics = new ArrayList<>();
if (randomBoolean()) {
metrics.add(MeanSquaredErrorTests.createRandom());
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException;
@ -29,7 +30,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
}
@Override
@ -38,7 +39,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
}
public static BinarySoftClassification createRandom() {
List<SoftClassificationMetric> metrics = new ArrayList<>();
List<EvaluationMetric> metrics = new ArrayList<>();
if (randomBoolean()) {
metrics.add(AucRocTests.createRandom());
}

View File

@ -5,22 +5,24 @@
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
import org.junit.After;
import org.junit.Before;
import java.util.Arrays;
import java.util.List;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@ -117,6 +119,69 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
}
public void testEvaluate_Precision() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
assertThat(
precisionResult.getClasses(),
equalTo(
Arrays.asList(
new Precision.PerClassResult("ant", 1.0 / 15),
new Precision.PerClassResult("cat", 1.0 / 15),
new Precision.PerClassResult("dog", 1.0 / 15),
new Precision.PerClassResult("fox", 1.0 / 15),
new Precision.PerClassResult("mouse", 1.0 / 15))));
assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75));
}
public void testEvaluate_Precision_CardinalityTooHigh() {
ElasticsearchStatusException e =
expectThrows(
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX,
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision(4)))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
}
public void testEvaluate_Recall() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
assertThat(
recallResult.getClasses(),
equalTo(
Arrays.asList(
new Recall.PerClassResult("ant", 1.0 / 15),
new Recall.PerClassResult("cat", 1.0 / 15),
new Recall.PerClassResult("dog", 1.0 / 15),
new Recall.PerClassResult("fox", 1.0 / 15),
new Recall.PerClassResult("mouse", 1.0 / 15))));
assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75));
}
public void testEvaluate_Recall_CardinalityTooHigh() {
ElasticsearchStatusException e =
expectThrows(
ElasticsearchStatusException.class,
() -> evaluateDataFrame(
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall(4)))));
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
}
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
@ -132,50 +197,50 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(
confusionMatrixResult.getConfusionMatrix(),
equalTo(Arrays.asList(
new ActualClass("ant",
new MulticlassConfusionMatrix.ActualClass("ant",
15,
Arrays.asList(
new PredictedClass("ant", 1L),
new PredictedClass("cat", 4L),
new PredictedClass("dog", 3L),
new PredictedClass("fox", 2L),
new PredictedClass("mouse", 5L)),
new MulticlassConfusionMatrix.PredictedClass("ant", 1L),
new MulticlassConfusionMatrix.PredictedClass("cat", 4L),
new MulticlassConfusionMatrix.PredictedClass("dog", 3L),
new MulticlassConfusionMatrix.PredictedClass("fox", 2L),
new MulticlassConfusionMatrix.PredictedClass("mouse", 5L)),
0),
new ActualClass("cat",
new MulticlassConfusionMatrix.ActualClass("cat",
15,
Arrays.asList(
new PredictedClass("ant", 3L),
new PredictedClass("cat", 1L),
new PredictedClass("dog", 5L),
new PredictedClass("fox", 4L),
new PredictedClass("mouse", 2L)),
new MulticlassConfusionMatrix.PredictedClass("ant", 3L),
new MulticlassConfusionMatrix.PredictedClass("cat", 1L),
new MulticlassConfusionMatrix.PredictedClass("dog", 5L),
new MulticlassConfusionMatrix.PredictedClass("fox", 4L),
new MulticlassConfusionMatrix.PredictedClass("mouse", 2L)),
0),
new ActualClass("dog",
new MulticlassConfusionMatrix.ActualClass("dog",
15,
Arrays.asList(
new PredictedClass("ant", 4L),
new PredictedClass("cat", 2L),
new PredictedClass("dog", 1L),
new PredictedClass("fox", 5L),
new PredictedClass("mouse", 3L)),
new MulticlassConfusionMatrix.PredictedClass("ant", 4L),
new MulticlassConfusionMatrix.PredictedClass("cat", 2L),
new MulticlassConfusionMatrix.PredictedClass("dog", 1L),
new MulticlassConfusionMatrix.PredictedClass("fox", 5L),
new MulticlassConfusionMatrix.PredictedClass("mouse", 3L)),
0),
new ActualClass("fox",
new MulticlassConfusionMatrix.ActualClass("fox",
15,
Arrays.asList(
new PredictedClass("ant", 5L),
new PredictedClass("cat", 3L),
new PredictedClass("dog", 2L),
new PredictedClass("fox", 1L),
new PredictedClass("mouse", 4L)),
new MulticlassConfusionMatrix.PredictedClass("ant", 5L),
new MulticlassConfusionMatrix.PredictedClass("cat", 3L),
new MulticlassConfusionMatrix.PredictedClass("dog", 2L),
new MulticlassConfusionMatrix.PredictedClass("fox", 1L),
new MulticlassConfusionMatrix.PredictedClass("mouse", 4L)),
0),
new ActualClass("mouse",
new MulticlassConfusionMatrix.ActualClass("mouse",
15,
Arrays.asList(
new PredictedClass("ant", 2L),
new PredictedClass("cat", 5L),
new PredictedClass("dog", 4L),
new PredictedClass("fox", 3L),
new PredictedClass("mouse", 1L)),
new MulticlassConfusionMatrix.PredictedClass("ant", 2L),
new MulticlassConfusionMatrix.PredictedClass("cat", 5L),
new MulticlassConfusionMatrix.PredictedClass("dog", 4L),
new MulticlassConfusionMatrix.PredictedClass("fox", 3L),
new MulticlassConfusionMatrix.PredictedClass("mouse", 1L)),
0))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
@ -194,17 +259,26 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
assertThat(
confusionMatrixResult.getConfusionMatrix(),
equalTo(Arrays.asList(
new ActualClass("ant",
new MulticlassConfusionMatrix.ActualClass("ant",
15,
Arrays.asList(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)),
Arrays.asList(
new MulticlassConfusionMatrix.PredictedClass("ant", 1L),
new MulticlassConfusionMatrix.PredictedClass("cat", 4L),
new MulticlassConfusionMatrix.PredictedClass("dog", 3L)),
7),
new ActualClass("cat",
new MulticlassConfusionMatrix.ActualClass("cat",
15,
Arrays.asList(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)),
Arrays.asList(
new MulticlassConfusionMatrix.PredictedClass("ant", 3L),
new MulticlassConfusionMatrix.PredictedClass("cat", 1L),
new MulticlassConfusionMatrix.PredictedClass("dog", 5L)),
6),
new ActualClass("dog",
new MulticlassConfusionMatrix.ActualClass("dog",
15,
Arrays.asList(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)),
Arrays.asList(
new MulticlassConfusionMatrix.PredictedClass("ant", 4L),
new MulticlassConfusionMatrix.PredictedClass("cat", 2L),
new MulticlassConfusionMatrix.PredictedClass("dog", 1L)),
8))));
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
}

View File

@ -27,6 +27,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
import org.junit.After;
import java.util.ArrayList;
@ -450,9 +452,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
evaluateDataFrame(
destIndex,
new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
dependentVariable, predictedClassField, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix())));
dependentVariable,
predictedClassField,
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
{ // Accuracy
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
@ -483,6 +487,24 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
}
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
{ // Precision
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2);
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
for (Precision.PerClassResult klass : precisionResult.getClasses()) {
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
assertThat(klass.getPrecision(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
}
}
{ // Recall
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3);
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
for (Recall.PerClassResult klass : recallResult.getClasses()) {
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
assertThat(klass.getRecall(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
}
}
}
protected String stateDocId() {

View File

@ -632,6 +632,58 @@ setup:
accuracy: 0.5 # 1 out of 2
overall_accuracy: 0.625 # 5 out of 8
---
"Test classification precision":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { "precision": {} }
}
}
}
- match:
classification.precision:
classes:
- class_name: "cat"
precision: 0.5 # 2 out of 4
- class_name: "dog"
precision: 0.6666666666666666 # 2 out of 3
- class_name: "mouse"
precision: 1.0 # 1 out of 1
avg_precision: 0.7222222222222222
---
"Test classification recall":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"classification": {
"actual_field": "classification_field_act.keyword",
"predicted_field": "classification_field_pred.keyword",
"metrics": { "recall": {} }
}
}
}
- match:
classification.recall:
classes:
- class_name: "cat"
recall: 0.6666666666666666 # 2 out of 3
- class_name: "dog"
recall: 0.6666666666666666 # 2 out of 3
- class_name: "mouse"
recall: 0.5 # 1 out of 2
avg_recall: 0.611111111111111
---
"Test classification multiclass_confusion_matrix":
- do:
ml.evaluate_data_frame: