mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-24 17:09:48 +00:00
This commit is contained in:
parent
903305284d
commit
cc4bc797f9
@ -35,6 +35,7 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
|
||||
|
||||
public class EvaluateDataFrameResponse implements ToXContentObject {
|
||||
@ -47,7 +48,7 @@ public class EvaluateDataFrameResponse implements ToXContentObject {
|
||||
ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
|
||||
String evaluationName = parser.currentName();
|
||||
parser.nextToken();
|
||||
Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric);
|
||||
Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, p -> parseMetric(evaluationName, p));
|
||||
List<EvaluationMetric.Result> knownMetrics =
|
||||
metrics.values().stream()
|
||||
.filter(Objects::nonNull) // Filter out null values returned by {@link EvaluateDataFrameResponse::parseMetric}.
|
||||
@ -56,10 +57,10 @@ public class EvaluateDataFrameResponse implements ToXContentObject {
|
||||
return new EvaluateDataFrameResponse(evaluationName, knownMetrics);
|
||||
}
|
||||
|
||||
private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException {
|
||||
private static EvaluationMetric.Result parseMetric(String evaluationName, XContentParser parser) throws IOException {
|
||||
String metricName = parser.currentName();
|
||||
try {
|
||||
return parser.namedObject(EvaluationMetric.Result.class, metricName, null);
|
||||
return parser.namedObject(EvaluationMetric.Result.class, registeredMetricName(evaluationName, metricName), null);
|
||||
} catch (NamedObjectNotFoundException e) {
|
||||
parser.skipChildren();
|
||||
// Metric name not recognized. Return {@code null} value here and filter it out later.
|
||||
|
@ -20,24 +20,36 @@ package org.elasticsearch.client.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
|
||||
|
||||
/**
|
||||
* Constructs the name under which a metric (or metric result) is registered.
|
||||
* The name is prefixed with evaluation name so that registered names are unique.
|
||||
*
|
||||
* @param evaluationName name of the evaluation
|
||||
* @param metricName name of the metric
|
||||
* @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
|
||||
*/
|
||||
public static String registeredMetricName(String evaluationName, String metricName) {
|
||||
return evaluationName + "." + metricName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
||||
return Arrays.asList(
|
||||
@ -47,39 +59,91 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
|
||||
// Evaluation metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(MulticlassConfusionMatrixMetric.NAME),
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
|
||||
AucRocMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
|
||||
PrecisionMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
|
||||
RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
|
||||
ConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
||||
AccuracyMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
||||
MulticlassConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
|
||||
MeanSquaredErrorMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent),
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
||||
RSquaredMetric::fromXContent),
|
||||
// Evaluation metrics results
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
|
||||
AucRocMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(MulticlassConfusionMatrixMetric.NAME),
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
|
||||
PrecisionMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
|
||||
RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
|
||||
ConfusionMatrixMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
|
||||
AccuracyMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
|
||||
MulticlassConfusionMatrixMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
|
||||
MeanSquaredErrorMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
||||
RSquaredMetric.Result::fromXContent)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -32,6 +32,10 @@ import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
* Evaluation of classification results.
|
||||
*/
|
||||
@ -48,10 +52,10 @@ public class Classification implements Evaluation {
|
||||
NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
|
||||
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(
|
||||
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
|
||||
}
|
||||
|
||||
public static Classification fromXContent(XContentParser parser) {
|
||||
|
@ -0,0 +1,201 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* {@link PrecisionMetric} is a metric that answers the question:
|
||||
* "What fraction of documents classified as X actually belongs to X?"
|
||||
* for any given class X
|
||||
*
|
||||
* equation: precision(X) = TP(X) / (TP(X) + FP(X))
|
||||
* where: TP(X) - number of true positives wrt X
|
||||
* FP(X) - number of false positives wrt X
|
||||
*/
|
||||
public class PrecisionMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "precision";
|
||||
|
||||
private static final ObjectParser<PrecisionMetric, Void> PARSER = new ObjectParser<>(NAME, true, PrecisionMetric::new);
|
||||
|
||||
public static PrecisionMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public PrecisionMetric() {}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(NAME);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
private static final ParseField CLASSES = new ParseField("classes");
|
||||
private static final ParseField AVG_PRECISION = new ParseField("avg_precision");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("precision_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), AVG_PRECISION);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of per-class results. */
|
||||
private final List<PerClassResult> classes;
|
||||
/** Average of per-class precisions. */
|
||||
private final double avgPrecision;
|
||||
|
||||
public Result(List<PerClassResult> classes, double avgPrecision) {
|
||||
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
|
||||
this.avgPrecision = avgPrecision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public List<PerClassResult> getClasses() {
|
||||
return classes;
|
||||
}
|
||||
|
||||
public double getAvgPrecision() {
|
||||
return avgPrecision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASSES.getPreferredName(), classes);
|
||||
builder.field(AVG_PRECISION.getPreferredName(), avgPrecision);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.classes, that.classes)
|
||||
&& this.avgPrecision == that.avgPrecision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classes, avgPrecision);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PerClassResult implements ToXContentObject {
|
||||
|
||||
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
private static final ParseField PRECISION = new ParseField("precision");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||
new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||
PARSER.declareDouble(constructorArg(), PRECISION);
|
||||
}
|
||||
|
||||
/** Name of the class. */
|
||||
private final String className;
|
||||
/** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */
|
||||
private final double precision;
|
||||
|
||||
public PerClassResult(String className, double precision) {
|
||||
this.className = Objects.requireNonNull(className);
|
||||
this.precision = precision;
|
||||
}
|
||||
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public double getPrecision() {
|
||||
return precision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
builder.field(PRECISION.getPreferredName(), precision);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PerClassResult that = (PerClassResult) o;
|
||||
return Objects.equals(this.className, that.className)
|
||||
&& this.precision == that.precision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, precision);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,201 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* {@link RecallMetric} is a metric that answers the question:
|
||||
* "What fraction of documents belonging to X have been predicted as X by the classifier?"
|
||||
* for any given class X
|
||||
*
|
||||
* equation: recall(X) = TP(X) / (TP(X) + FN(X))
|
||||
* where: TP(X) - number of true positives wrt X
|
||||
* FN(X) - number of false negatives wrt X
|
||||
*/
|
||||
public class RecallMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "recall";
|
||||
|
||||
private static final ObjectParser<RecallMetric, Void> PARSER = new ObjectParser<>(NAME, true, RecallMetric::new);
|
||||
|
||||
public static RecallMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public RecallMetric() {}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(NAME);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
private static final ParseField CLASSES = new ParseField("classes");
|
||||
private static final ParseField AVG_RECALL = new ParseField("avg_recall");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("recall_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), AVG_RECALL);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of per-class results. */
|
||||
private final List<PerClassResult> classes;
|
||||
/** Average of per-class recalls. */
|
||||
private final double avgRecall;
|
||||
|
||||
public Result(List<PerClassResult> classes, double avgRecall) {
|
||||
this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
|
||||
this.avgRecall = avgRecall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public List<PerClassResult> getClasses() {
|
||||
return classes;
|
||||
}
|
||||
|
||||
public double getAvgRecall() {
|
||||
return avgRecall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASSES.getPreferredName(), classes);
|
||||
builder.field(AVG_RECALL.getPreferredName(), avgRecall);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.classes, that.classes)
|
||||
&& this.avgRecall == that.avgRecall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classes, avgRecall);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PerClassResult implements ToXContentObject {
|
||||
|
||||
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
private static final ParseField RECALL = new ParseField("recall");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||
new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||
PARSER.declareDouble(constructorArg(), RECALL);
|
||||
}
|
||||
|
||||
/** Name of the class. */
|
||||
private final String className;
|
||||
/** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */
|
||||
private final double recall;
|
||||
|
||||
public PerClassResult(String className, double recall) {
|
||||
this.className = Objects.requireNonNull(className);
|
||||
this.recall = recall;
|
||||
}
|
||||
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public double getRecall() {
|
||||
return recall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
builder.field(RECALL.getPreferredName(), recall);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PerClassResult that = (PerClassResult) o;
|
||||
return Objects.equals(this.className, that.className)
|
||||
&& this.recall == that.recall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, recall);
|
||||
}
|
||||
}
|
||||
}
|
@ -33,6 +33,10 @@ import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
* Evaluation of regression results.
|
||||
*/
|
||||
@ -49,10 +53,10 @@ public class Regression implements Evaluation {
|
||||
NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
|
||||
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(
|
||||
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
|
||||
}
|
||||
|
||||
public static Regression fromXContent(XContentParser parser) {
|
||||
|
@ -33,6 +33,7 @@ import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
@ -59,7 +60,8 @@ public class BinarySoftClassification implements Evaluation {
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(constructorArg(), PREDICTED_PROBABILITY_FIELD);
|
||||
PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, n, null), METRICS);
|
||||
PARSER.declareNamedObjects(
|
||||
optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), null), METRICS);
|
||||
}
|
||||
|
||||
public static BinarySoftClassification fromXContent(XContentParser parser) {
|
||||
|
@ -1860,6 +1860,70 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
new AccuracyMetric.ActualClass("ant", 1, 0.0))));
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6)); // 6 out of 10 examples were classified correctly
|
||||
}
|
||||
{ // Precision
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(
|
||||
actualClassField,
|
||||
predictedClassField,
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
|
||||
evaluateDataFrameResponse.getMetricByName(
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
|
||||
assertThat(
|
||||
precisionResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6),
|
||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
|
||||
}
|
||||
{ // Recall
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
indexName,
|
||||
null,
|
||||
new Classification(
|
||||
actualClassField,
|
||||
predictedClassField,
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
|
||||
evaluateDataFrameResponse.getMetricByName(
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
|
||||
assertThat(
|
||||
recallResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
// 3 out of 5 examples labeled as "cat" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6),
|
||||
// 3 out of 4 examples labeled as "dog" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75),
|
||||
// no examples labeled as "ant" were classified correctly
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(0.45));
|
||||
}
|
||||
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
||||
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
||||
new EvaluateDataFrameRequest(
|
||||
|
@ -128,6 +128,7 @@ import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
|
||||
import static org.hamcrest.CoreMatchers.endsWith;
|
||||
import static org.hamcrest.CoreMatchers.equalTo;
|
||||
@ -688,7 +689,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(51, namedXContents.size());
|
||||
assertEquals(55, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
@ -730,26 +731,36 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
AccuracyMetric.NAME,
|
||||
MulticlassConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
hasItems(
|
||||
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
|
||||
registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME),
|
||||
registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME),
|
||||
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||
assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
AccuracyMetric.NAME,
|
||||
MulticlassConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
hasItems(
|
||||
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
|
||||
registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME),
|
||||
registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME),
|
||||
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
||||
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
|
||||
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
|
||||
|
@ -3372,7 +3372,9 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
"predicted_class", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
new AccuracyMetric(), // <5>
|
||||
new MulticlassConfusionMatrixMetric(3)); // <6>
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6>
|
||||
new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7>
|
||||
new MulticlassConfusionMatrixMetric(3)); // <8>
|
||||
// end::evaluate-data-frame-evaluation-classification
|
||||
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
||||
@ -3382,16 +3384,34 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
|
||||
double accuracy = accuracyResult.getOverallAccuracy(); // <2>
|
||||
|
||||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3>
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
|
||||
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3>
|
||||
double precision = precisionResult.getAvgPrecision(); // <4>
|
||||
|
||||
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4>
|
||||
long otherActualClassCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5>
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
|
||||
response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5>
|
||||
double recall = recallResult.getAvgRecall(); // <6>
|
||||
|
||||
MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
|
||||
response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <7>
|
||||
|
||||
List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
|
||||
long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
|
||||
// end::evaluate-data-frame-results-classification
|
||||
|
||||
assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
|
||||
assertThat(accuracy, equalTo(0.6));
|
||||
|
||||
assertThat(
|
||||
precisionResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
|
||||
assertThat(precision, equalTo(0.675));
|
||||
|
||||
assertThat(
|
||||
recallResult.getMetricName(),
|
||||
equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
|
||||
assertThat(recall, equalTo(0.45));
|
||||
|
||||
assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
||||
assertThat(
|
||||
confusionMatrix,
|
||||
@ -3412,7 +3432,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
4L,
|
||||
Arrays.asList(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
|
||||
0L))));
|
||||
assertThat(otherActualClassCount, equalTo(0L));
|
||||
assertThat(otherClassesCount, equalTo(0L));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,6 +64,8 @@ public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<Eva
|
||||
metrics = randomSubsetOf(
|
||||
Arrays.asList(
|
||||
AccuracyMetricResultTests.randomResult(),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetricResultTests.randomResult(),
|
||||
org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetricResultTests.randomResult(),
|
||||
MulticlassConfusionMatrixMetricResultTests.randomResult()));
|
||||
break;
|
||||
default:
|
||||
|
@ -41,6 +41,8 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
||||
randomSubsetOf(
|
||||
Arrays.asList(
|
||||
AccuracyMetricTests.createRandom(),
|
||||
PrecisionMetricTests.createRandom(),
|
||||
RecallMetricTests.createRandom(),
|
||||
MulticlassConfusionMatrixMetricTests.createRandom()));
|
||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class PrecisionMetricResultTests extends AbstractXContentTestCase<Result> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static Result randomResult() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double precision = randomDoubleBetween(0.0, 1.0, true);
|
||||
classes.add(new PerClassResult(classNames.get(i), precision));
|
||||
}
|
||||
double avgPrecision = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(classes, avgPrecision);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Result createTestInstance() {
|
||||
return randomResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class PrecisionMetricTests extends AbstractXContentTestCase<PrecisionMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
static PrecisionMetric createRandom() {
|
||||
return new PrecisionMetric();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PrecisionMetric createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected PrecisionMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return PrecisionMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class RecallMetricResultTests extends AbstractXContentTestCase<Result> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static Result randomResult() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double recall = randomDoubleBetween(0.0, 1.0, true);
|
||||
classes.add(new PerClassResult(classNames.get(i), recall));
|
||||
}
|
||||
double avgRecall = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(classes, avgRecall);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Result createTestInstance() {
|
||||
return randomResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RecallMetricTests extends AbstractXContentTestCase<RecallMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
static RecallMetric createRandom() {
|
||||
return new RecallMetric();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RecallMetric createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RecallMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return RecallMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
@ -53,7 +53,9 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
|
||||
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
|
||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<5> Accuracy
|
||||
<6> Multiclass confusion matrix of size 3
|
||||
<6> Precision
|
||||
<7> Recall
|
||||
<8> Multiclass confusion matrix of size 3
|
||||
|
||||
===== Regression
|
||||
|
||||
@ -104,9 +106,13 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
|
||||
|
||||
<1> Fetching accuracy metric by name
|
||||
<2> Fetching the actual accuracy value
|
||||
<3> Fetching multiclass confusion matrix metric by name
|
||||
<4> Fetching the contents of the confusion matrix
|
||||
<5> Fetching the number of classes that were not included in the matrix
|
||||
<3> Fetching precision metric by name
|
||||
<4> Fetching the actual precision value
|
||||
<5> Fetching recall metric by name
|
||||
<6> Fetching the actual recall value
|
||||
<7> Fetching multiclass confusion matrix metric by name
|
||||
<8> Fetching the contents of the confusion matrix
|
||||
<9> Fetching the number of classes that were not included in the matrix
|
||||
|
||||
===== Regression
|
||||
|
||||
@ -118,4 +124,4 @@ include-tagged::{doc-tests-file}[{api}-results-regression]
|
||||
<1> Fetching mean squared error metric by name
|
||||
<2> Fetching the actual mean squared error value
|
||||
<3> Fetching R squared metric by name
|
||||
<4> Fetching the actual R squared value
|
||||
<4> Fetching the actual R squared value
|
||||
|
@ -79,7 +79,6 @@ import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;
|
||||
import org.elasticsearch.xpack.core.ml.MlMetadata;
|
||||
import org.elasticsearch.xpack.core.ml.MlTasks;
|
||||
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
|
||||
@ -91,6 +90,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.FlushJobAction;
|
||||
@ -146,18 +146,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
||||
@ -267,6 +256,9 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPlugin {
|
||||
|
||||
@ -474,7 +466,8 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
||||
|
||||
@Override
|
||||
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
return Arrays.asList(
|
||||
return Stream.concat(
|
||||
Arrays.asList(
|
||||
// graph
|
||||
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.GRAPH, GraphFeatureSetUsage::new),
|
||||
// logstash
|
||||
@ -502,28 +495,6 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
|
||||
new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
|
||||
// ML - Data frame evaluation
|
||||
new NamedWriteableRegistry.Entry(
|
||||
Evaluation.class,
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification.NAME.getPreferredName(),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification::new),
|
||||
new NamedWriteableRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix.Result::new),
|
||||
new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, Accuracy.NAME.getPreferredName(), Accuracy.Result::new),
|
||||
new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new),
|
||||
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new),
|
||||
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new),
|
||||
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new),
|
||||
new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), AucRoc.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix.Result::new),
|
||||
// ML - Inference preprocessing
|
||||
new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new),
|
||||
new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
|
||||
@ -628,7 +599,9 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
|
||||
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.ANALYTICS, AnalyticsFeatureSetUsage::new),
|
||||
// Enrich
|
||||
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.ENRICH, EnrichFeatureSet.Usage::new)
|
||||
);
|
||||
).stream(),
|
||||
MlEvaluationNamedXContentProvider.getNamedWriteables().stream()
|
||||
).collect(toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -7,12 +7,14 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
@ -76,8 +78,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
for (EvaluationMetric metric : getMetrics()) {
|
||||
// Fetch aggregations requested by individual metrics
|
||||
List<AggregationBuilder> aggs = metric.aggs(getActualField(), getPredictedField());
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField());
|
||||
aggs.v1().forEach(searchSourceBuilder::aggregation);
|
||||
aggs.v2().forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
@ -6,10 +6,12 @@
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
@ -30,7 +32,7 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
|
||||
* @param predictedField the field that stores the predicted value (class name or probability)
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
List<AggregationBuilder> aggs(String actualField, String predictedField);
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField);
|
||||
|
||||
/**
|
||||
* Processes given aggregations as a step towards computing result
|
||||
|
@ -5,109 +5,179 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
|
||||
|
||||
/**
|
||||
* Constructs the name under which a metric (or metric result) is registered.
|
||||
* The name is prefixed with evaluation name so that registered names are unique.
|
||||
*
|
||||
* @param evaluationName name of the evaluation
|
||||
* @param metricName name of the metric
|
||||
* @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
|
||||
*/
|
||||
public static String registeredMetricName(ParseField evaluationName, ParseField metricName) {
|
||||
return registeredMetricName(evaluationName.getPreferredName(), metricName.getPreferredName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs the name under which a metric (or metric result) is registered.
|
||||
* The name is prefixed with evaluation name so that registered names are unique.
|
||||
*
|
||||
* @param evaluationName name of the evaluation
|
||||
* @param metricName name of the metric
|
||||
* @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
|
||||
*/
|
||||
public static String registeredMetricName(String evaluationName, String metricName) {
|
||||
return evaluationName + "." + metricName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
||||
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||
return Arrays.asList(
|
||||
// Evaluations
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent),
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent),
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent),
|
||||
|
||||
// Evaluations
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
|
||||
BinarySoftClassification::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
|
||||
// Soft classification metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME)),
|
||||
AucRoc::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, Precision.NAME)),
|
||||
Precision::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, Recall.NAME)),
|
||||
Recall::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME)),
|
||||
ConfusionMatrix::fromXContent),
|
||||
|
||||
// Soft classification metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
|
||||
ConfusionMatrix::fromXContent));
|
||||
// Classification metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)),
|
||||
MulticlassConfusionMatrix::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)),
|
||||
Accuracy::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent),
|
||||
|
||||
// Classification metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
|
||||
MulticlassConfusionMatrix::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent));
|
||||
|
||||
// Regression metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
|
||||
|
||||
return namedXContent;
|
||||
// Regression metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)),
|
||||
MeanSquaredError::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
|
||||
RSquared::fromXContent)
|
||||
);
|
||||
}
|
||||
|
||||
public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
||||
return Arrays.asList(
|
||||
// Evaluations
|
||||
new NamedWriteableRegistry.Entry(Evaluation.class,
|
||||
BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new),
|
||||
new NamedWriteableRegistry.Entry(Evaluation.class,
|
||||
Classification.NAME.getPreferredName(),
|
||||
Classification::new),
|
||||
new NamedWriteableRegistry.Entry(Evaluation.class,
|
||||
Regression.NAME.getPreferredName(),
|
||||
Regression::new),
|
||||
|
||||
// Evaluations
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(),
|
||||
Classification::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
|
||||
// Evaluation metrics
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
|
||||
AucRoc::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(BinarySoftClassification.NAME, Precision.NAME),
|
||||
Precision::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(BinarySoftClassification.NAME, Recall.NAME),
|
||||
Recall::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
|
||||
ConfusionMatrix::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
|
||||
MulticlassConfusionMatrix::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Classification.NAME, Accuracy.NAME),
|
||||
Accuracy::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
||||
MeanSquaredError::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Regression.NAME, RSquared.NAME),
|
||||
RSquared::new),
|
||||
|
||||
// Evaluation Metrics
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(),
|
||||
AucRoc::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(),
|
||||
Precision::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(),
|
||||
Recall::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
|
||||
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
||||
RSquared.NAME.getPreferredName(),
|
||||
RSquared::new));
|
||||
|
||||
// Evaluation Metrics Results
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
|
||||
AucRoc.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME,
|
||||
ScoreByThresholdResult::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MulticlassConfusionMatrix.NAME.getPreferredName(),
|
||||
MulticlassConfusionMatrix.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
Accuracy.NAME.getPreferredName(),
|
||||
Accuracy.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
RSquared.NAME.getPreferredName(),
|
||||
RSquared.Result::new));
|
||||
|
||||
return namedWriteables;
|
||||
// Evaluation metrics results
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
|
||||
AucRoc.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(BinarySoftClassification.NAME, ScoreByThresholdResult.NAME),
|
||||
ScoreByThresholdResult::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
|
||||
ConfusionMatrix.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
|
||||
MulticlassConfusionMatrix.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Classification.NAME, Accuracy.NAME),
|
||||
Accuracy.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(
|
||||
Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
|
||||
org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
||||
MeanSquaredError.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Regression.NAME, RSquared.NAME),
|
||||
RSquared.Result::new)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
@ -18,8 +19,10 @@ import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
@ -34,6 +37,7 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* {@link Accuracy} is a metric that answers the question:
|
||||
@ -41,7 +45,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
|
||||
*
|
||||
* equation: accuracy = 1/n * Σ(y == y´)
|
||||
*/
|
||||
public class Accuracy implements ClassificationMetric {
|
||||
public class Accuracy implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("accuracy");
|
||||
|
||||
@ -68,7 +72,7 @@ public class Accuracy implements ClassificationMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -77,16 +81,18 @@ public class Accuracy implements ClassificationMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
Script accuracyScript = new Script(buildScript(actualField, predictedField));
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.terms(CLASSES_AGG_NAME)
|
||||
.field(actualField)
|
||||
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
|
||||
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript));
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(CLASSES_AGG_NAME)
|
||||
.field(actualField)
|
||||
.subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
|
||||
AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -169,7 +175,7 @@ public class Accuracy implements ClassificationMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -20,6 +21,8 @@ import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Evaluation of classification results.
|
||||
*/
|
||||
@ -33,13 +36,13 @@ public class Classification implements Evaluation {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<ClassificationMetric>) a[2]));
|
||||
NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS);
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
|
||||
}
|
||||
|
||||
public static Classification fromXContent(XContentParser parser) {
|
||||
@ -61,22 +64,22 @@ public class Classification implements Evaluation {
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<ClassificationMetric> metrics;
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
|
||||
public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
|
||||
this.metrics = initMetrics(metrics, Classification::defaultMetrics);
|
||||
}
|
||||
|
||||
private static List<ClassificationMetric> defaultMetrics() {
|
||||
private static List<EvaluationMetric> defaultMetrics() {
|
||||
return Arrays.asList(new MulticlassConfusionMatrix());
|
||||
}
|
||||
|
||||
public Classification(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedField = in.readString();
|
||||
this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
|
||||
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -95,7 +98,7 @@ public class Classification implements Evaluation {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ClassificationMetric> getMetrics() {
|
||||
public List<EvaluationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@ -118,8 +121,8 @@ public class Classification implements Evaluation {
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (ClassificationMetric metric : metrics) {
|
||||
builder.field(metric.getWriteableName(), metric);
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
public interface ClassificationMetric extends EvaluationMetric {
|
||||
}
|
@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
@ -19,10 +20,12 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.BucketOrder;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.Cardinality;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
@ -38,13 +41,14 @@ import java.util.stream.Collectors;
|
||||
import static java.util.Comparator.comparing;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* {@link MulticlassConfusionMatrix} is a metric that answers the question:
|
||||
* "How many examples belonging to class X were classified as Y by the classifier?"
|
||||
* "How many documents belonging to class X were classified as Y by the classifier?"
|
||||
* for all the possible class pairs {X, Y}.
|
||||
*/
|
||||
public class MulticlassConfusionMatrix implements ClassificationMetric {
|
||||
public class MulticlassConfusionMatrix implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
|
||||
|
||||
@ -92,7 +96,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -105,13 +109,15 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
if (topActualClassNames == null) { // This is step 1
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(size));
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(size)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
if (result == null) { // This is step 2
|
||||
KeyedFilter[] keyedFiltersActual =
|
||||
@ -122,15 +128,17 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
||||
topActualClassNames.stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
|
||||
.field(actualField),
|
||||
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
|
||||
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
|
||||
.otherBucket(true)
|
||||
.otherBucketKey(OTHER_BUCKET_KEY)));
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
|
||||
.field(actualField),
|
||||
AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
|
||||
.subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
|
||||
.otherBucket(true)
|
||||
.otherBucketKey(OTHER_BUCKET_KEY))),
|
||||
Collections.emptyList());
|
||||
}
|
||||
return Collections.emptyList();
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -232,7 +240,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -301,7 +309,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
|
||||
|
||||
/** Name of the actual class. */
|
||||
private final String actualClass;
|
||||
/** Number of documents (examples) belonging to the {code actualClass} class. */
|
||||
/** Number of documents belonging to the {code actualClass} class. */
|
||||
private final long actualClassDocCount;
|
||||
/** List of predicted classes. */
|
||||
private final List<PredictedClass> predictedClasses;
|
||||
|
@ -0,0 +1,347 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.BucketOrder;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* {@link Precision} is a metric that answers the question:
|
||||
* "What fraction of documents classified as X actually belongs to X?"
|
||||
* for any given class X
|
||||
*
|
||||
* equation: precision(X) = TP(X) / (TP(X) + FP(X))
|
||||
* where: TP(X) - number of true positives wrt X
|
||||
* FP(X) - number of false positives wrt X
|
||||
*/
|
||||
public class Precision implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("precision");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
private static final String AGG_NAME_PREFIX = "classification_precision_";
|
||||
static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
|
||||
static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
|
||||
static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
|
||||
static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";
|
||||
|
||||
private static Script buildScript(Object...args) {
|
||||
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
||||
}
|
||||
|
||||
private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);
|
||||
|
||||
public static Precision fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
|
||||
|
||||
private final int maxClassesCardinality;
|
||||
private String actualField;
|
||||
private List<String> topActualClassNames;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
public Precision() {
|
||||
this((Integer) null);
|
||||
}
|
||||
|
||||
// Visible for testing
|
||||
public Precision(@Nullable Integer maxClassesCardinality) {
|
||||
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
|
||||
}
|
||||
|
||||
public Precision(StreamInput in) throws IOException {
|
||||
this.maxClassesCardinality = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField = actualField;
|
||||
if (topActualClassNames == null) { // This is step 1
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
|
||||
.field(actualField)
|
||||
.order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
|
||||
.size(maxClassesCardinality)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
if (result == null) { // This is step 2
|
||||
KeyedFilter[] keyedFiltersPredicted =
|
||||
topActualClassNames.stream()
|
||||
.map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
|
||||
.toArray(KeyedFilter[]::new);
|
||||
Script script = buildScript(actualField, predictedField);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)
|
||||
.subAggregation(AggregationBuilders.avg(PER_PREDICTED_CLASS_PRECISION_AGG_NAME).script(script))),
|
||||
Arrays.asList(
|
||||
PipelineAggregatorBuilders.avgBucket(
|
||||
AVG_PRECISION_AGG_NAME,
|
||||
BY_PREDICTED_CLASS_AGG_NAME + ">" + PER_PREDICTED_CLASS_PRECISION_AGG_NAME)));
|
||||
}
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
|
||||
Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
|
||||
if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
|
||||
// This means there were more than {@code maxClassesCardinality} buckets.
|
||||
// We cannot calculate average precision accurately, so we fail.
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
|
||||
}
|
||||
topActualClassNames =
|
||||
topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
|
||||
}
|
||||
if (result == null &&
|
||||
aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
|
||||
aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||
Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
|
||||
NumericMetricsAggregation.SingleValue avgPrecisionAgg = aggs.get(AVG_PRECISION_AGG_NAME);
|
||||
List<PerClassResult> classes = new ArrayList<>(byPredictedClassAgg.getBuckets().size());
|
||||
for (Filters.Bucket bucket : byPredictedClassAgg.getBuckets()) {
|
||||
String className = bucket.getKeyAsString();
|
||||
NumericMetricsAggregation.SingleValue precisionAgg = bucket.getAggregations().get(PER_PREDICTED_CLASS_PRECISION_AGG_NAME);
|
||||
double precision = precisionAgg.value();
|
||||
if (Double.isFinite(precision)) {
|
||||
classes.add(new PerClassResult(className, precision));
|
||||
}
|
||||
}
|
||||
result = new Result(classes, avgPrecisionAgg.value());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(maxClassesCardinality);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(NAME.getPreferredName());
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final ParseField CLASSES = new ParseField("classes");
|
||||
private static final ParseField AVG_PRECISION = new ParseField("avg_precision");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("precision_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), AVG_PRECISION);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of per-class results. */
|
||||
private final List<PerClassResult> classes;
|
||||
/** Average of per-class precisions. */
|
||||
private final double avgPrecision;
|
||||
|
||||
public Result(List<PerClassResult> classes, double avgPrecision) {
|
||||
this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
|
||||
this.avgPrecision = avgPrecision;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
|
||||
this.avgPrecision = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public List<PerClassResult> getClasses() {
|
||||
return classes;
|
||||
}
|
||||
|
||||
public double getAvgPrecision() {
|
||||
return avgPrecision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeList(classes);
|
||||
out.writeDouble(avgPrecision);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASSES.getPreferredName(), classes);
|
||||
builder.field(AVG_PRECISION.getPreferredName(), avgPrecision);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.classes, that.classes)
|
||||
&& this.avgPrecision == that.avgPrecision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classes, avgPrecision);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PerClassResult implements ToXContentObject, Writeable {
|
||||
|
||||
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
private static final ParseField PRECISION = new ParseField("precision");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||
new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||
PARSER.declareDouble(constructorArg(), PRECISION);
|
||||
}
|
||||
|
||||
/** Name of the class. */
|
||||
private final String className;
|
||||
/** Fraction of documents predicted as belonging to the {@code predictedClass} class predicted correctly. */
|
||||
private final double precision;
|
||||
|
||||
public PerClassResult(String className, double precision) {
|
||||
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
|
||||
this.precision = precision;
|
||||
}
|
||||
|
||||
public PerClassResult(StreamInput in) throws IOException {
|
||||
this.className = in.readString();
|
||||
this.precision = in.readDouble();
|
||||
}
|
||||
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public double getPrecision() {
|
||||
return precision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(className);
|
||||
out.writeDouble(precision);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
builder.field(PRECISION.getPreferredName(), precision);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PerClassResult that = (PerClassResult) o;
|
||||
return Objects.equals(this.className, that.className)
|
||||
&& this.precision == that.precision;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, precision);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,321 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* {@link Recall} is a metric that answers the question:
|
||||
* "What fraction of documents belonging to X have been predicted as X by the classifier?"
|
||||
* for any given class X
|
||||
*
|
||||
* equation: recall(X) = TP(X) / (TP(X) + FN(X))
|
||||
* where: TP(X) - number of true positives wrt X
|
||||
* FN(X) - number of false negatives wrt X
|
||||
*/
|
||||
public class Recall implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("recall");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
||||
private static final String AGG_NAME_PREFIX = "classification_recall_";
|
||||
static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
|
||||
static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
|
||||
static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";
|
||||
|
||||
private static Script buildScript(Object...args) {
|
||||
return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
||||
}
|
||||
|
||||
private static final ObjectParser<Recall, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new);
|
||||
|
||||
public static Recall fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
|
||||
|
||||
private final int maxClassesCardinality;
|
||||
private String actualField;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
public Recall() {
|
||||
this((Integer) null);
|
||||
}
|
||||
|
||||
// Visible for testing
|
||||
public Recall(@Nullable Integer maxClassesCardinality) {
|
||||
this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
|
||||
}
|
||||
|
||||
public Recall(StreamInput in) throws IOException {
|
||||
this.maxClassesCardinality = in.readVInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
// Store given {@code actualField} for the purpose of generating error message in {@code process}.
|
||||
this.actualField = actualField;
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
Script script = buildScript(actualField, predictedField);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
|
||||
.field(actualField)
|
||||
.size(maxClassesCardinality)
|
||||
.subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
|
||||
Arrays.asList(
|
||||
PipelineAggregatorBuilders.avgBucket(
|
||||
AVG_RECALL_AGG_NAME,
|
||||
BY_ACTUAL_CLASS_AGG_NAME + ">" + PER_ACTUAL_CLASS_RECALL_AGG_NAME)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
if (result == null &&
|
||||
aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
|
||||
aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
||||
Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
|
||||
if (byActualClassAgg.getSumOfOtherDocCounts() > 0) {
|
||||
// This means there were more than {@code maxClassesCardinality} buckets.
|
||||
// We cannot calculate average recall accurately, so we fail.
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
|
||||
}
|
||||
NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
|
||||
List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
|
||||
for (Terms.Bucket bucket : byActualClassAgg.getBuckets()) {
|
||||
String className = bucket.getKeyAsString();
|
||||
NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
|
||||
classes.add(new PerClassResult(className, recallAgg.value()));
|
||||
}
|
||||
result = new Result(classes, avgRecallAgg.value());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeVInt(maxClassesCardinality);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(NAME.getPreferredName());
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final ParseField CLASSES = new ParseField("classes");
|
||||
private static final ParseField AVG_RECALL = new ParseField("avg_recall");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("recall_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
||||
PARSER.declareDouble(constructorArg(), AVG_RECALL);
|
||||
}
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/** List of per-class results. */
|
||||
private final List<PerClassResult> classes;
|
||||
/** Average of per-class recalls. */
|
||||
private final double avgRecall;
|
||||
|
||||
public Result(List<PerClassResult> classes, double avgRecall) {
|
||||
this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
|
||||
this.avgRecall = avgRecall;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
|
||||
this.avgRecall = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Classification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public List<PerClassResult> getClasses() {
|
||||
return classes;
|
||||
}
|
||||
|
||||
public double getAvgRecall() {
|
||||
return avgRecall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeList(classes);
|
||||
out.writeDouble(avgRecall);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASSES.getPreferredName(), classes);
|
||||
builder.field(AVG_RECALL.getPreferredName(), avgRecall);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(this.classes, that.classes)
|
||||
&& this.avgRecall == that.avgRecall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(classes, avgRecall);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PerClassResult implements ToXContentObject, Writeable {
|
||||
|
||||
private static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||
private static final ParseField RECALL = new ParseField("recall");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
||||
new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(constructorArg(), CLASS_NAME);
|
||||
PARSER.declareDouble(constructorArg(), RECALL);
|
||||
}
|
||||
|
||||
/** Name of the class. */
|
||||
private final String className;
|
||||
/** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */
|
||||
private final double recall;
|
||||
|
||||
public PerClassResult(String className, double recall) {
|
||||
this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
|
||||
this.recall = recall;
|
||||
}
|
||||
|
||||
public PerClassResult(StreamInput in) throws IOException {
|
||||
this.className = in.readString();
|
||||
this.recall = in.readDouble();
|
||||
}
|
||||
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public double getRecall() {
|
||||
return recall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(className);
|
||||
out.writeDouble(recall);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||
builder.field(RECALL.getPreferredName(), recall);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
PerClassResult that = (PerClassResult) o;
|
||||
return Objects.equals(this.className, that.className)
|
||||
&& this.recall == that.recall;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(className, recall);
|
||||
}
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
@ -15,7 +16,9 @@ import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -27,12 +30,14 @@ import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Calculates the mean squared error between two known numerical fields.
|
||||
*
|
||||
* equation: mse = 1/n * Σ(y - y´)^2
|
||||
*/
|
||||
public class MeanSquaredError implements RegressionMetric {
|
||||
public class MeanSquaredError implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("mean_squared_error");
|
||||
|
||||
@ -62,11 +67,13 @@ public class MeanSquaredError implements RegressionMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))),
|
||||
Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -82,7 +89,7 @@ public class MeanSquaredError implements RegressionMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Regression.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -125,7 +132,7 @@ public class MeanSquaredError implements RegressionMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Regression.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -6,6 +6,7 @@
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
@ -15,9 +16,11 @@ import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
|
||||
import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -29,6 +32,8 @@ import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Calculates R-Squared between two known numerical fields.
|
||||
*
|
||||
@ -37,7 +42,7 @@ import java.util.Optional;
|
||||
* SSres = Σ(y - y´)^2, The residual sum of squares
|
||||
* SStot = Σ(y - y_mean)^2, The total sum of squares
|
||||
*/
|
||||
public class RSquared implements RegressionMetric {
|
||||
public class RSquared implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("r_squared");
|
||||
|
||||
@ -67,13 +72,15 @@ public class RSquared implements RegressionMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
|
||||
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(
|
||||
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
|
||||
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField)),
|
||||
Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -97,7 +104,7 @@ public class RSquared implements RegressionMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Regression.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -140,7 +147,7 @@ public class RSquared implements RegressionMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(Regression.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -20,6 +21,8 @@ import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Evaluation of regression results.
|
||||
*/
|
||||
@ -33,13 +36,13 @@ public class Regression implements Evaluation {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<RegressionMetric>) a[2]));
|
||||
NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS);
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
|
||||
}
|
||||
|
||||
public static Regression fromXContent(XContentParser parser) {
|
||||
@ -61,22 +64,22 @@ public class Regression implements Evaluation {
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<RegressionMetric> metrics;
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
|
||||
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
|
||||
this.metrics = initMetrics(metrics, Regression::defaultMetrics);
|
||||
}
|
||||
|
||||
private static List<RegressionMetric> defaultMetrics() {
|
||||
private static List<EvaluationMetric> defaultMetrics() {
|
||||
return Arrays.asList(new MeanSquaredError(), new RSquared());
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedField = in.readString();
|
||||
this.metrics = in.readNamedWriteableList(RegressionMetric.class);
|
||||
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -95,7 +98,7 @@ public class Regression implements Evaluation {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RegressionMetric> getMetrics() {
|
||||
public List<EvaluationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@ -118,8 +121,8 @@ public class Regression implements Evaluation {
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (RegressionMetric metric : metrics) {
|
||||
builder.field(metric.getWriteableName(), metric);
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
public interface RegressionMetric extends EvaluationMetric {
|
||||
}
|
@ -6,6 +6,7 @@
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
@ -15,6 +16,8 @@ import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
@ -23,9 +26,9 @@ import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery;
|
||||
|
||||
abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
|
||||
abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
|
||||
|
||||
public static final ParseField AT = new ParseField("at");
|
||||
|
||||
@ -63,11 +66,11 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
||||
}
|
||||
|
||||
@Override
|
||||
public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
return aggsAt(actualField, predictedProbabilityField);
|
||||
return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
@ -18,8 +19,10 @@ import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
|
||||
import org.elasticsearch.search.aggregations.metrics.Percentiles;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
@ -33,7 +36,8 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery;
|
||||
|
||||
/**
|
||||
* Area under the curve (AUC) of the receiver operating characteristic (ROC).
|
||||
@ -53,7 +57,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassific
|
||||
* When this is used for multi-class classification, it will calculate the ROC
|
||||
* curve of each class versus the rest.
|
||||
*/
|
||||
public class AucRoc implements SoftClassificationMetric {
|
||||
public class AucRoc implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("auc_roc");
|
||||
|
||||
@ -88,7 +92,7 @@ public class AucRoc implements SoftClassificationMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(BinarySoftClassification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -123,9 +127,9 @@ public class AucRoc implements SoftClassificationMetric {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
|
||||
if (result != null) {
|
||||
return Collections.emptyList();
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
|
||||
AggregationBuilder percentilesForClassValueAgg =
|
||||
@ -138,7 +142,9 @@ public class AucRoc implements SoftClassificationMetric {
|
||||
.filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
|
||||
.subAggregation(
|
||||
AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
|
||||
return Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg);
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(percentilesForClassValueAgg, percentilesForRestAgg),
|
||||
Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -330,7 +336,7 @@ public class AucRoc implements SoftClassificationMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(BinarySoftClassification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -12,7 +12,10 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -20,6 +23,8 @@ import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Evaluation of binary soft classification methods, e.g. outlier detection.
|
||||
* This is useful to evaluate problems where a model outputs a probability of whether
|
||||
@ -34,19 +39,23 @@ public class BinarySoftClassification implements Evaluation {
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List<SoftClassificationMetric>) a[2]));
|
||||
NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(SoftClassificationMetric.class, n, null), METRICS);
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
|
||||
}
|
||||
|
||||
public static BinarySoftClassification fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
static QueryBuilder actualIsTrueQuery(String actualField) {
|
||||
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
|
||||
}
|
||||
|
||||
/**
|
||||
* The field where the actual class is marked up.
|
||||
* The value of this field is assumed to either be 1 or 0, or true or false.
|
||||
@ -61,16 +70,16 @@ public class BinarySoftClassification implements Evaluation {
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<SoftClassificationMetric> metrics;
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public BinarySoftClassification(String actualField, String predictedProbabilityField,
|
||||
@Nullable List<SoftClassificationMetric> metrics) {
|
||||
@Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
|
||||
this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics);
|
||||
}
|
||||
|
||||
private static List<SoftClassificationMetric> defaultMetrics() {
|
||||
private static List<EvaluationMetric> defaultMetrics() {
|
||||
return Arrays.asList(
|
||||
new AucRoc(false),
|
||||
new Precision(Arrays.asList(0.25, 0.5, 0.75)),
|
||||
@ -81,7 +90,7 @@ public class BinarySoftClassification implements Evaluation {
|
||||
public BinarySoftClassification(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedProbabilityField = in.readString();
|
||||
this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class);
|
||||
this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -100,7 +109,7 @@ public class BinarySoftClassification implements Evaluation {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SoftClassificationMetric> getMetrics() {
|
||||
public List<EvaluationMetric> getMetrics() {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
@ -123,7 +132,7 @@ public class BinarySoftClassification implements Evaluation {
|
||||
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
|
||||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (SoftClassificationMetric metric : metrics) {
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
|
@ -21,6 +21,8 @@ import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("confusion_matrix");
|
||||
@ -46,7 +48,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(BinarySoftClassification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -129,7 +131,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(BinarySoftClassification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -19,6 +19,8 @@ import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
public class Precision extends AbstractConfusionMatrixMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("precision");
|
||||
@ -44,7 +46,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(BinarySoftClassification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -19,6 +19,8 @@ import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
public class Recall extends AbstractConfusionMatrixMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("recall");
|
||||
@ -44,7 +46,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
return registeredMetricName(BinarySoftClassification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -5,6 +5,7 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
@ -13,9 +14,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
public class ScoreByThresholdResult implements EvaluationMetricResult {
|
||||
|
||||
public static final String NAME = "score_by_threshold_result";
|
||||
public static final ParseField NAME = new ParseField("score_by_threshold_result");
|
||||
|
||||
private final String name;
|
||||
private final double[] thresholds;
|
||||
@ -36,7 +39,7 @@ public class ScoreByThresholdResult implements EvaluationMetricResult {
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
return registeredMetricName(BinarySoftClassification.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1,17 +0,0 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
|
||||
public interface SoftClassificationMetric extends EvaluationMetric {
|
||||
|
||||
static QueryBuilder actualIsTrueQuery(String actualField) {
|
||||
return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
|
||||
}
|
||||
}
|
@ -31,7 +31,7 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
||||
namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
namedWriteables.addAll(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
|
||||
return new NamedWriteableRegistry(namedWriteables);
|
||||
}
|
||||
@ -46,13 +46,11 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
|
||||
|
||||
@Override
|
||||
protected Request createTestInstance() {
|
||||
Request request = new Request();
|
||||
int indicesCount = randomIntBetween(1, 5);
|
||||
List<String> indices = new ArrayList<>(indicesCount);
|
||||
for (int i = 0; i < indicesCount; i++) {
|
||||
indices.add(randomAlphaOfLength(10));
|
||||
}
|
||||
request.setIndices(indices);
|
||||
QueryProvider queryProvider = null;
|
||||
if (randomBoolean()) {
|
||||
try {
|
||||
@ -62,10 +60,11 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
}
|
||||
request.setQueryProvider(queryProvider);
|
||||
Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
|
||||
request.setEvaluation(evaluation);
|
||||
return request;
|
||||
return new Request()
|
||||
.setIndices(indices)
|
||||
.setQueryProvider(queryProvider)
|
||||
.setEvaluation(evaluation);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -11,7 +11,10 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AccuracyResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
|
||||
@ -22,7 +25,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -30,11 +33,13 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
|
||||
String evaluationName = randomAlphaOfLength(10);
|
||||
List<EvaluationMetricResult> metrics =
|
||||
Arrays.asList(
|
||||
AccuracyResultTests.createRandom(),
|
||||
PrecisionResultTests.createRandom(),
|
||||
RecallResultTests.createRandom(),
|
||||
MulticlassConfusionMatrixResultTests.createRandom(),
|
||||
new MeanSquaredError.Result(randomDouble()),
|
||||
new RSquared.Result(randomDouble()));
|
||||
int numMetrics = randomIntBetween(0, metrics.size());
|
||||
return new Response(evaluationName, metrics.subList(0, numMetrics));
|
||||
return new Response(evaluationName, randomSubsetOf(metrics));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -17,15 +17,9 @@ import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accuracy.Result> {
|
||||
public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Accuracy.Result createTestInstance() {
|
||||
public static Result createRandom() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<ActualClass> actualClasses = new ArrayList<>(numClasses);
|
||||
@ -38,7 +32,17 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accurac
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Accuracy.Result> instanceReader() {
|
||||
return Accuracy.Result::new;
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Result createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Result> instanceReader() {
|
||||
return Result::new;
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
@ -19,8 +20,10 @@ import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
@ -42,7 +45,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -51,10 +54,12 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||
}
|
||||
|
||||
public static Classification createRandom() {
|
||||
List<ClassificationMetric> metrics =
|
||||
List<EvaluationMetric> metrics =
|
||||
randomSubsetOf(
|
||||
Arrays.asList(
|
||||
AccuracyTests.createRandom(),
|
||||
PrecisionTests.createRandom(),
|
||||
RecallTests.createRandom(),
|
||||
MulticlassConfusionMatrixTests.createRandom()));
|
||||
return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
@ -101,10 +106,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||
}
|
||||
|
||||
public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() {
|
||||
ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
|
||||
ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
|
||||
ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
|
||||
ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
|
||||
EvaluationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
|
||||
EvaluationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
|
||||
EvaluationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
|
||||
EvaluationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
|
||||
|
||||
Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
|
||||
assertThat(metric1.getResult(), isEmpty());
|
||||
@ -168,7 +173,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||
* Number of steps is configurable.
|
||||
* Upon reaching the last step, the result is produced.
|
||||
*/
|
||||
private static class FakeClassificationMetric implements ClassificationMetric {
|
||||
private static class FakeClassificationMetric implements EvaluationMetric {
|
||||
|
||||
private final String name;
|
||||
private final int numSteps;
|
||||
@ -191,8 +196,8 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
return Collections.emptyList();
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -6,10 +6,12 @@
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
@ -25,9 +27,9 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregati
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
|
||||
@ -74,8 +76,8 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||
|
||||
public void testAggs() {
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
|
||||
List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
|
||||
assertThat(aggs, is(not(empty())));
|
||||
Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs("act", "pred");
|
||||
assertThat(aggs, isTuple(not(empty()), empty()));
|
||||
assertThat(confusionMatrix.getResult(), isEmpty());
|
||||
}
|
||||
|
||||
@ -109,7 +111,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
assertThat(
|
||||
@ -151,7 +153,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
|
||||
MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
|
||||
confusionMatrix.process(aggs);
|
||||
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
|
||||
assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
|
||||
assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
|
||||
assertThat(
|
||||
|
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.PerClassResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class PrecisionResultTests extends AbstractWireSerializingTestCase<Result> {
|
||||
|
||||
public static Result createRandom() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double precision = randomDoubleBetween(0.0, 1.0, true);
|
||||
classes.add(new PerClassResult(classNames.get(i), precision));
|
||||
}
|
||||
double avgPrecision = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(classes, avgPrecision);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Result createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Result> instanceReader() {
|
||||
return Result::new;
|
||||
}
|
||||
}
|
@ -0,0 +1,118 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
|
||||
|
||||
@Override
|
||||
protected Precision doParseInstance(XContentParser parser) throws IOException {
|
||||
return Precision.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Precision createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Precision> instanceReader() {
|
||||
return Precision::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public static Precision createRandom() {
|
||||
return new Precision();
|
||||
}
|
||||
|
||||
public void testProcess() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME),
|
||||
mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
|
||||
mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
Precision precision = new Precision();
|
||||
precision.process(aggs);
|
||||
|
||||
assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(precision.getResult().get(), equalTo(new Precision.Result(Collections.emptyList(), 0.8123)));
|
||||
}
|
||||
|
||||
public void testProcess_GivenMissingAgg() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Precision precision = new Precision();
|
||||
precision.process(aggs);
|
||||
assertThat(precision.getResult(), isEmpty());
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Precision precision = new Precision();
|
||||
precision.process(aggs);
|
||||
assertThat(precision.getResult(), isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenAggOfWrongType() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
|
||||
mockFilters(Precision.AVG_PRECISION_AGG_NAME)
|
||||
));
|
||||
Precision precision = new Precision();
|
||||
precision.process(aggs);
|
||||
assertThat(precision.getResult(), isEmpty());
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockSingleValue(Precision.BY_PREDICTED_CLASS_AGG_NAME, 1.0),
|
||||
mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123)
|
||||
));
|
||||
Precision precision = new Precision();
|
||||
precision.process(aggs);
|
||||
assertThat(precision.getResult(), isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenCardinalityTooHigh() {
|
||||
Aggregations aggs =
|
||||
new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
|
||||
Precision precision = new Precision();
|
||||
precision.aggs("foo", "bar");
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.PerClassResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class RecallResultTests extends AbstractWireSerializingTestCase<Result> {
|
||||
|
||||
public static Result createRandom() {
|
||||
int numClasses = randomIntBetween(2, 100);
|
||||
List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
|
||||
List<PerClassResult> classes = new ArrayList<>(numClasses);
|
||||
for (int i = 0; i < numClasses; i++) {
|
||||
double recall = randomDoubleBetween(0.0, 1.0, true);
|
||||
classes.add(new PerClassResult(classNames.get(i), recall));
|
||||
}
|
||||
double avgRecall = randomDoubleBetween(0.0, 1.0, true);
|
||||
return new Result(classes, avgRecall);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Result createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Result> instanceReader() {
|
||||
return Result::new;
|
||||
}
|
||||
}
|
@ -0,0 +1,117 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class RecallTests extends AbstractSerializingTestCase<Recall> {
|
||||
|
||||
@Override
|
||||
protected Recall doParseInstance(XContentParser parser) throws IOException {
|
||||
return Recall.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Recall createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Recall> instanceReader() {
|
||||
return Recall::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public static Recall createRandom() {
|
||||
return new Recall();
|
||||
}
|
||||
|
||||
public void testProcess() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
|
||||
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
Recall recall = new Recall();
|
||||
recall.process(aggs);
|
||||
|
||||
assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty()));
|
||||
assertThat(recall.getResult().get(), equalTo(new Recall.Result(Collections.emptyList(), 0.8123)));
|
||||
}
|
||||
|
||||
public void testProcess_GivenMissingAgg() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Recall recall = new Recall();
|
||||
recall.process(aggs);
|
||||
assertThat(recall.getResult(), isEmpty());
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
Recall recall = new Recall();
|
||||
recall.process(aggs);
|
||||
assertThat(recall.getResult(), isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenAggOfWrongType() {
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
|
||||
mockTerms(Recall.AVG_RECALL_AGG_NAME)
|
||||
));
|
||||
Recall recall = new Recall();
|
||||
recall.process(aggs);
|
||||
assertThat(recall.getResult(), isEmpty());
|
||||
}
|
||||
{
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockSingleValue(Recall.BY_ACTUAL_CLASS_AGG_NAME, 1.0),
|
||||
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)
|
||||
));
|
||||
Recall recall = new Recall();
|
||||
recall.process(aggs);
|
||||
assertThat(recall.getResult(), isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenCardinalityTooHigh() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
|
||||
mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
|
||||
Recall recall = new Recall();
|
||||
recall.aggs("foo", "bar");
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
|
||||
}
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
||||
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.hamcrest.Description;
|
||||
import org.hamcrest.Matcher;
|
||||
import org.hamcrest.TypeSafeMatcher;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
public class TupleMatchers {
|
||||
|
||||
private static class TupleMatcher<V1, V2> extends TypeSafeMatcher<Tuple<? extends V1, ? extends V2>> {
|
||||
|
||||
private final Matcher<? super V1> v1Matcher;
|
||||
private final Matcher<? super V2> v2Matcher;
|
||||
|
||||
private TupleMatcher(Matcher<? super V1> v1Matcher, Matcher<? super V2> v2Matcher) {
|
||||
this.v1Matcher = v1Matcher;
|
||||
this.v2Matcher = v2Matcher;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean matchesSafely(final Tuple<? extends V1, ? extends V2> item) {
|
||||
return item != null && v1Matcher.matches(item.v1()) && v2Matcher.matches(item.v2());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void describeTo(final Description description) {
|
||||
description.appendText("expected tuple matching ").appendList("[", ", ", "]", Arrays.asList(v1Matcher, v2Matcher));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a matcher that matches iff:
|
||||
* 1. the examined tuple's <code>v1()</code> matches the specified <code>v1Matcher</code>
|
||||
* and
|
||||
* 2. the examined tuple's <code>v2()</code> matches the specified <code>v2Matcher</code>
|
||||
* For example:
|
||||
* <pre>assertThat(Tuple.tuple("myValue1", "myValue2"), isTuple(startsWith("my"), containsString("Val")))</pre>
|
||||
*/
|
||||
public static <V1, V2> TupleMatcher<? extends V1, ? extends V2> isTuple(Matcher<? super V1> v1Matcher, Matcher<? super V2> v2Matcher) {
|
||||
return new TupleMatcher(v1Matcher, v2Matcher);
|
||||
}
|
||||
}
|
@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -29,7 +30,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -38,7 +39,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
}
|
||||
|
||||
public static Regression createRandom() {
|
||||
List<RegressionMetric> metrics = new ArrayList<>();
|
||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(MeanSquaredErrorTests.createRandom());
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -29,7 +30,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -38,7 +39,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
|
||||
}
|
||||
|
||||
public static BinarySoftClassification createRandom() {
|
||||
List<SoftClassificationMetric> metrics = new ArrayList<>();
|
||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(AucRocTests.createRandom());
|
||||
}
|
||||
|
@ -5,22 +5,24 @@
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
@ -117,6 +119,69 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||
assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
precisionResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Precision.PerClassResult("ant", 1.0 / 15),
|
||||
new Precision.PerClassResult("cat", 1.0 / 15),
|
||||
new Precision.PerClassResult("dog", 1.0 / 15),
|
||||
new Precision.PerClassResult("fox", 1.0 / 15),
|
||||
new Precision.PerClassResult("mouse", 1.0 / 15))));
|
||||
assertThat(precisionResult.getAvgPrecision(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_Precision_CardinalityTooHigh() {
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX,
|
||||
new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Precision(4)))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
recallResult.getClasses(),
|
||||
equalTo(
|
||||
Arrays.asList(
|
||||
new Recall.PerClassResult("ant", 1.0 / 15),
|
||||
new Recall.PerClassResult("cat", 1.0 / 15),
|
||||
new Recall.PerClassResult("dog", 1.0 / 15),
|
||||
new Recall.PerClassResult("fox", 1.0 / 15),
|
||||
new Recall.PerClassResult("mouse", 1.0 / 15))));
|
||||
assertThat(recallResult.getAvgRecall(), equalTo(5.0 / 75));
|
||||
}
|
||||
|
||||
public void testEvaluate_Recall_CardinalityTooHigh() {
|
||||
ElasticsearchStatusException e =
|
||||
expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
() -> evaluateDataFrame(
|
||||
ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, Arrays.asList(new Recall(4)))));
|
||||
assertThat(e.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
|
||||
}
|
||||
|
||||
public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
@ -132,50 +197,50 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||
assertThat(
|
||||
confusionMatrixResult.getConfusionMatrix(),
|
||||
equalTo(Arrays.asList(
|
||||
new ActualClass("ant",
|
||||
new MulticlassConfusionMatrix.ActualClass("ant",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 1L),
|
||||
new PredictedClass("cat", 4L),
|
||||
new PredictedClass("dog", 3L),
|
||||
new PredictedClass("fox", 2L),
|
||||
new PredictedClass("mouse", 5L)),
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 1L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 4L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 3L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("fox", 2L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("mouse", 5L)),
|
||||
0),
|
||||
new ActualClass("cat",
|
||||
new MulticlassConfusionMatrix.ActualClass("cat",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 3L),
|
||||
new PredictedClass("cat", 1L),
|
||||
new PredictedClass("dog", 5L),
|
||||
new PredictedClass("fox", 4L),
|
||||
new PredictedClass("mouse", 2L)),
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 3L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 1L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 5L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("fox", 4L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("mouse", 2L)),
|
||||
0),
|
||||
new ActualClass("dog",
|
||||
new MulticlassConfusionMatrix.ActualClass("dog",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 4L),
|
||||
new PredictedClass("cat", 2L),
|
||||
new PredictedClass("dog", 1L),
|
||||
new PredictedClass("fox", 5L),
|
||||
new PredictedClass("mouse", 3L)),
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 4L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 2L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 1L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("fox", 5L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("mouse", 3L)),
|
||||
0),
|
||||
new ActualClass("fox",
|
||||
new MulticlassConfusionMatrix.ActualClass("fox",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 5L),
|
||||
new PredictedClass("cat", 3L),
|
||||
new PredictedClass("dog", 2L),
|
||||
new PredictedClass("fox", 1L),
|
||||
new PredictedClass("mouse", 4L)),
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 5L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 3L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 2L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("fox", 1L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("mouse", 4L)),
|
||||
0),
|
||||
new ActualClass("mouse",
|
||||
new MulticlassConfusionMatrix.ActualClass("mouse",
|
||||
15,
|
||||
Arrays.asList(
|
||||
new PredictedClass("ant", 2L),
|
||||
new PredictedClass("cat", 5L),
|
||||
new PredictedClass("dog", 4L),
|
||||
new PredictedClass("fox", 3L),
|
||||
new PredictedClass("mouse", 1L)),
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 2L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 5L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 4L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("fox", 3L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("mouse", 1L)),
|
||||
0))));
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
@ -194,17 +259,26 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
|
||||
assertThat(
|
||||
confusionMatrixResult.getConfusionMatrix(),
|
||||
equalTo(Arrays.asList(
|
||||
new ActualClass("ant",
|
||||
new MulticlassConfusionMatrix.ActualClass("ant",
|
||||
15,
|
||||
Arrays.asList(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)),
|
||||
Arrays.asList(
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 1L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 4L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 3L)),
|
||||
7),
|
||||
new ActualClass("cat",
|
||||
new MulticlassConfusionMatrix.ActualClass("cat",
|
||||
15,
|
||||
Arrays.asList(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)),
|
||||
Arrays.asList(
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 3L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 1L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 5L)),
|
||||
6),
|
||||
new ActualClass("dog",
|
||||
new MulticlassConfusionMatrix.ActualClass("dog",
|
||||
15,
|
||||
Arrays.asList(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)),
|
||||
Arrays.asList(
|
||||
new MulticlassConfusionMatrix.PredictedClass("ant", 4L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("cat", 2L),
|
||||
new MulticlassConfusionMatrix.PredictedClass("dog", 1L)),
|
||||
8))));
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
|
||||
}
|
||||
|
@ -27,6 +27,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -450,9 +452,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
evaluateDataFrame(
|
||||
destIndex,
|
||||
new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
|
||||
dependentVariable, predictedClassField, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix())));
|
||||
dependentVariable,
|
||||
predictedClassField,
|
||||
Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
|
||||
|
||||
{ // Accuracy
|
||||
Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
@ -483,6 +487,24 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
}
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
|
||||
{ // Precision
|
||||
Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2);
|
||||
assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
|
||||
for (Precision.PerClassResult klass : precisionResult.getClasses()) {
|
||||
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
|
||||
assertThat(klass.getPrecision(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
|
||||
}
|
||||
}
|
||||
|
||||
{ // Recall
|
||||
Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3);
|
||||
assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
|
||||
for (Recall.PerClassResult klass : recallResult.getClasses()) {
|
||||
assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
|
||||
assertThat(klass.getRecall(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected String stateDocId() {
|
||||
|
@ -632,6 +632,58 @@ setup:
|
||||
accuracy: 0.5 # 1 out of 2
|
||||
overall_accuracy: 0.625 # 5 out of 8
|
||||
---
|
||||
"Test classification precision":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "classification_field_pred.keyword",
|
||||
"metrics": { "precision": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match:
|
||||
classification.precision:
|
||||
classes:
|
||||
- class_name: "cat"
|
||||
precision: 0.5 # 2 out of 4
|
||||
- class_name: "dog"
|
||||
precision: 0.6666666666666666 # 2 out of 3
|
||||
- class_name: "mouse"
|
||||
precision: 1.0 # 1 out of 1
|
||||
avg_precision: 0.7222222222222222
|
||||
---
|
||||
"Test classification recall":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"classification": {
|
||||
"actual_field": "classification_field_act.keyword",
|
||||
"predicted_field": "classification_field_pred.keyword",
|
||||
"metrics": { "recall": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match:
|
||||
classification.recall:
|
||||
classes:
|
||||
- class_name: "cat"
|
||||
recall: 0.6666666666666666 # 2 out of 3
|
||||
- class_name: "dog"
|
||||
recall: 0.6666666666666666 # 2 out of 3
|
||||
- class_name: "mouse"
|
||||
recall: 0.5 # 1 out of 2
|
||||
avg_recall: 0.611111111111111
|
||||
---
|
||||
"Test classification multiclass_confusion_matrix":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
|
Loading…
x
Reference in New Issue
Block a user