* [ML] Adds support for regression.mean_squared_error to eval API * addressing PR comments * fixing tests
This commit is contained in:
parent
1636701d69
commit
c82d9c5b50
|
@ -18,6 +18,8 @@
|
|||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
|
@ -38,12 +40,15 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
// Evaluations
|
||||
new NamedXContentRegistry.Entry(
|
||||
Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
|
||||
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
|
||||
// Evaluation metrics
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
|
||||
// Evaluation metrics results
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
|
||||
|
@ -51,6 +56,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* Calculates the mean squared error between two known numerical fields.
|
||||
*
|
||||
* equation: mse = 1/n * Σ(y - y´)^2
|
||||
*/
|
||||
public class MeanSquaredErrorMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "mean_squared_error";
|
||||
|
||||
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER =
|
||||
new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new);
|
||||
|
||||
public static MeanSquaredErrorMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public MeanSquaredErrorMetric() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
// create static hash code from name as there are currently no unique fields per class instance
|
||||
return Objects.hashCode(NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static final ParseField ERROR = new ParseField("error");
|
||||
private final double error;
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), ERROR);
|
||||
}
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR.getPreferredName(), error);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(that.error, this.error);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(error);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Evaluation of regression results.
|
||||
*/
|
||||
public class Regression implements Evaluation {
|
||||
|
||||
public static final String NAME = "regression";
|
||||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
|
||||
}
|
||||
|
||||
public static Regression fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String predictedField;
|
||||
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public Regression(String actualField, String predictedField) {
|
||||
this(actualField, predictedField, (List<EvaluationMetric>)null);
|
||||
}
|
||||
|
||||
public Regression(String actualField, String predictedField, EvaluationMetric... metrics) {
|
||||
this(actualField, predictedField, Arrays.asList(metrics));
|
||||
}
|
||||
|
||||
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = actualField;
|
||||
this.predictedField = predictedField;
|
||||
this.metrics = metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
if (metrics != null) {
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Regression that = (Regression) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
&& Objects.equals(that.metrics, this.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedField, metrics);
|
||||
}
|
||||
}
|
|
@ -123,6 +123,8 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
|
|||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
||||
|
@ -1578,6 +1580,33 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
|
||||
assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
|
||||
assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
|
||||
|
||||
String regressionIndex = "evaluate-regression-test-index";
|
||||
createIndex(regressionIndex, mappingForRegression());
|
||||
BulkRequest regressionBulk = new BulkRequest()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||
.add(docForRegression(regressionIndex, 0.3, 0.1)) // #0
|
||||
.add(docForRegression(regressionIndex, 0.3, 0.2)) // #1
|
||||
.add(docForRegression(regressionIndex, 0.3, 0.3)) // #2
|
||||
.add(docForRegression(regressionIndex, 0.3, 0.4)) // #3
|
||||
.add(docForRegression(regressionIndex, 0.3, 0.7)) // #4
|
||||
.add(docForRegression(regressionIndex, 0.5, 0.2)) // #5
|
||||
.add(docForRegression(regressionIndex, 0.5, 0.3)) // #6
|
||||
.add(docForRegression(regressionIndex, 0.5, 0.4)) // #7
|
||||
.add(docForRegression(regressionIndex, 0.5, 0.8)) // #8
|
||||
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
|
||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||
|
||||
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression));
|
||||
|
||||
evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
|
||||
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
|
||||
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
|
||||
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
|
||||
}
|
||||
|
||||
private static XContentBuilder defaultMappingForTest() throws IOException {
|
||||
|
@ -1615,6 +1644,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
|
||||
}
|
||||
|
||||
private static final String actualRegression = "regression_actual";
|
||||
private static final String probabilityRegression = "regression_prob";
|
||||
|
||||
private static XContentBuilder mappingForRegression() throws IOException {
|
||||
return XContentFactory.jsonBuilder().startObject()
|
||||
.startObject("properties")
|
||||
.startObject(actualRegression)
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.startObject(probabilityRegression)
|
||||
.field("type", "double")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject();
|
||||
}
|
||||
|
||||
private static IndexRequest docForRegression(String indexName, double act, double p) {
|
||||
return new IndexRequest()
|
||||
.index(indexName)
|
||||
.source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
|
||||
}
|
||||
|
||||
private void createIndex(String indexName, XContentBuilder mapping) throws IOException {
|
||||
highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT);
|
||||
}
|
||||
|
|
|
@ -60,6 +60,8 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction;
|
|||
import org.elasticsearch.client.indexlifecycle.UnfollowAction;
|
||||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
|
||||
|
@ -674,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(31, namedXContents.size());
|
||||
assertEquals(34, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -712,12 +714,14 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
|
||||
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
|
||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME));
|
||||
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME));
|
||||
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME));
|
||||
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
|
||||
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
|
||||
}
|
||||
|
||||
public void testApiNamingConventions() throws Exception {
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.elasticsearch.client.ml;
|
|||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
@ -45,6 +46,9 @@ public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<Eva
|
|||
if (randomBoolean()) {
|
||||
metrics.add(ConfusionMatrixMetricResultTests.randomResult());
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(MeanSquaredErrorMetricResultTests.randomResult());
|
||||
}
|
||||
return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class MeanSquaredErrorMetricResultTests extends AbstractXContentTestCase<MeanSquaredErrorMetric.Result> {
|
||||
|
||||
public static MeanSquaredErrorMetric.Result randomResult() {
|
||||
return new MeanSquaredErrorMetric.Result(randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredErrorMetric.Result createTestInstance() {
|
||||
return randomResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredErrorMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return MeanSquaredErrorMetric.Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class MeanSquaredErrorMetricTests extends AbstractXContentTestCase<MeanSquaredErrorMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredErrorMetric createTestInstance() {
|
||||
return new MeanSquaredErrorMetric();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredErrorMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return MeanSquaredErrorMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression createTestInstance() {
|
||||
return randomBoolean() ?
|
||||
new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10)) :
|
||||
new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredErrorMetric()));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
||||
return Regression.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
// allow unknown fields in the root of the object only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
}
|
|
@ -8,6 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
|||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
|
||||
|
@ -28,6 +31,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
// Evaluations
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
|
||||
BinarySoftClassification::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
|
||||
|
||||
// Soft classification metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent));
|
||||
|
@ -36,6 +40,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
|
||||
ConfusionMatrix::fromXContent));
|
||||
|
||||
// Regression metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
|
||||
|
||||
return namedXContent;
|
||||
}
|
||||
|
||||
|
@ -45,6 +52,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
// Evaluations
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
||||
BinarySoftClassification::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
|
||||
|
||||
// Evaluation Metrics
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(),
|
||||
|
@ -55,6 +63,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
Recall::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError::new));
|
||||
|
||||
// Evaluation Metrics Results
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
|
||||
|
@ -63,6 +74,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
ScoreByThresholdResult::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
|
||||
ConfusionMatrix.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError.Result::new));
|
||||
|
||||
return namedWriteables;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Calculates the mean squared error between two known numerical fields.
|
||||
*
|
||||
* equation: mse = 1/n * Σ(y - y´)^2
|
||||
*/
|
||||
public class MeanSquaredError implements RegressionMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("mean_squared_error");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
|
||||
private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
|
||||
|
||||
private static String buildScript(Object...args) {
|
||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
||||
}
|
||||
|
||||
private static final ObjectParser<MeanSquaredError, Void> PARSER =
|
||||
new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new);
|
||||
|
||||
public static MeanSquaredError fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public MeanSquaredError(StreamInput in) {
|
||||
|
||||
}
|
||||
|
||||
public MeanSquaredError() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
|
||||
return value == null ? null : new Result(value.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
// create static hash code from name as there are currently no unique fields per class instance
|
||||
return Objects.hashCode(NAME.getPreferredName());
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final String ERROR = "error";
|
||||
private final double error;
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.error = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(error);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR, error);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.index.query.BoolQueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Evaluation of regression results.
|
||||
*/
|
||||
public class Regression implements Evaluation {
|
||||
|
||||
public static final ParseField NAME = new ParseField("regression");
|
||||
|
||||
private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
|
||||
private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
|
||||
private static final ParseField METRICS = new ParseField("metrics");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
|
||||
NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<RegressionMetric>) a[2]));
|
||||
|
||||
static {
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
|
||||
PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
|
||||
PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
|
||||
(p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS);
|
||||
}
|
||||
|
||||
public static Regression fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* The field containing the actual value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String actualField;
|
||||
|
||||
/**
|
||||
* The field containing the predicted value
|
||||
* The value of this field is assumed to be numeric
|
||||
*/
|
||||
private final String predictedField;
|
||||
|
||||
/**
|
||||
* The list of metrics to calculate
|
||||
*/
|
||||
private final List<RegressionMetric> metrics;
|
||||
|
||||
public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
|
||||
this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
|
||||
this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
|
||||
this.metrics = initMetrics(metrics);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
this.actualField = in.readString();
|
||||
this.predictedField = in.readString();
|
||||
this.metrics = in.readNamedWriteableList(RegressionMetric.class);
|
||||
}
|
||||
|
||||
private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
|
||||
List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics;
|
||||
if (metrics.isEmpty()) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
||||
}
|
||||
Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private static List<RegressionMetric> defaultMetrics() {
|
||||
List<RegressionMetric> defaultMetrics = new ArrayList<>(1);
|
||||
defaultMetrics.add(new MeanSquaredError());
|
||||
return defaultMetrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchSourceBuilder buildSearch() {
|
||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
||||
.filter(QueryBuilders.existsQuery(actualField))
|
||||
.filter(QueryBuilders.existsQuery(predictedField));
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
||||
for (RegressionMetric metric : metrics) {
|
||||
List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
|
||||
aggs.forEach(searchSourceBuilder::aggregation);
|
||||
}
|
||||
return searchSourceBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
|
||||
List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
|
||||
for (RegressionMetric metric : metrics) {
|
||||
results.add(metric.evaluate(searchResponse.getAggregations()));
|
||||
}
|
||||
listener.onResponse(results);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(actualField);
|
||||
out.writeString(predictedField);
|
||||
out.writeNamedWriteableList(metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
|
||||
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (RegressionMetric metric : metrics) {
|
||||
builder.field(metric.getWriteableName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Regression that = (Regression) o;
|
||||
return Objects.equals(that.actualField, this.actualField)
|
||||
&& Objects.equals(that.predictedField, this.predictedField)
|
||||
&& Objects.equals(that.metrics, this.metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(actualField, predictedField, metrics);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteable;
|
||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface RegressionMetric extends ToXContentObject, NamedWriteable {
|
||||
|
||||
/**
|
||||
* Returns the name of the metric (which may differ to the writeable name)
|
||||
*/
|
||||
String getMetricName();
|
||||
|
||||
/**
|
||||
* Builds the aggregation that collect required data to compute the metric
|
||||
* @param actualField the field that stores the actual value
|
||||
* @param predictedField the field that stores the predicted value
|
||||
* @return the aggregations required to compute the metric
|
||||
*/
|
||||
List<AggregationBuilder> aggs(String actualField, String predictedField);
|
||||
|
||||
/**
|
||||
* Calculates the metric result
|
||||
* @param aggs the aggregations
|
||||
* @return the metric result
|
||||
*/
|
||||
EvaluationMetricResult evaluate(Aggregations aggs);
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
|
||||
|
||||
@Override
|
||||
protected MeanSquaredError doParseInstance(XContentParser parser) throws IOException {
|
||||
return MeanSquaredError.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredError createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<MeanSquaredError> instanceReader() {
|
||||
return MeanSquaredError::new;
|
||||
}
|
||||
|
||||
public static MeanSquaredError createRandom() {
|
||||
return new MeanSquaredError();
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("regression_mean_squared_error", 0.8123),
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
MeanSquaredError mse = new MeanSquaredError();
|
||||
EvaluationMetricResult result = mse.evaluate(aggs);
|
||||
|
||||
String expected = "{\"error\":0.8123}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingAggs() {
|
||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
MeanSquaredError mse = new MeanSquaredError();
|
||||
EvaluationMetricResult result = mse.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
}
|
||||
|
||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
||||
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
|
||||
when(agg.getName()).thenReturn(name);
|
||||
when(agg.value()).thenReturn(value);
|
||||
return agg;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
||||
@Override
|
||||
protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
public static Regression createRandom() {
|
||||
List<RegressionMetric> metrics = Collections.singletonList(MeanSquaredErrorTests.createRandom());
|
||||
return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), randomBoolean() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression doParseInstance(XContentParser parser) throws IOException {
|
||||
return Regression.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Regression createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Regression> instanceReader() {
|
||||
return Regression::new;
|
||||
}
|
||||
|
||||
public void testConstructor_GivenEmptyMetrics() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", "bar", Collections.emptyList()));
|
||||
assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
|
||||
}
|
||||
}
|
|
@ -72,9 +72,9 @@ integTest.runner {
|
|||
'ml/evaluate_data_frame/Test given missing index',
|
||||
'ml/evaluate_data_frame/Test given index does not exist',
|
||||
'ml/evaluate_data_frame/Test given missing evaluation',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always true',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always false',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with emtpy metrics',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always true',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always false',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with empty metrics',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given precision with threshold less than zero',
|
||||
|
@ -83,6 +83,7 @@ integTest.runner {
|
|||
'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
|
||||
'ml/delete_job_force/Test cannot force delete a non-existent job',
|
||||
'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
|
||||
'ml/delete_model_snapshot/Test delete snapshot missing job_id',
|
||||
|
|
|
@ -8,6 +8,8 @@ setup:
|
|||
"is_outlier": false,
|
||||
"is_outlier_int": 0,
|
||||
"outlier_score": 0.0,
|
||||
"regression_field_act": 10.9,
|
||||
"regression_field_pred": 10.9,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -20,6 +22,8 @@ setup:
|
|||
"is_outlier": false,
|
||||
"is_outlier_int": 0,
|
||||
"outlier_score": 0.2,
|
||||
"regression_field_act": 12.0,
|
||||
"regression_field_pred": 9.9,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -32,6 +36,8 @@ setup:
|
|||
"is_outlier": false,
|
||||
"is_outlier_int": 0,
|
||||
"outlier_score": 0.3,
|
||||
"regression_field_act": 20.9,
|
||||
"regression_field_pred": 5.9,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -44,6 +50,8 @@ setup:
|
|||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.3,
|
||||
"regression_field_act": 11.9,
|
||||
"regression_field_pred": 11.9,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -56,6 +64,8 @@ setup:
|
|||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.4,
|
||||
"regression_field_act": 42.9,
|
||||
"regression_field_pred": 42.9,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -68,6 +78,8 @@ setup:
|
|||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.5,
|
||||
"regression_field_act": 0.42,
|
||||
"regression_field_pred": 0.42,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -80,6 +92,8 @@ setup:
|
|||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.9,
|
||||
"regression_field_act": 1.1235813,
|
||||
"regression_field_pred": 1.12358,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -92,6 +106,8 @@ setup:
|
|||
"is_outlier": true,
|
||||
"is_outlier_int": 1,
|
||||
"outlier_score": 0.95,
|
||||
"regression_field_act": -5.20,
|
||||
"regression_field_pred": -5.1,
|
||||
"all_true_field": true,
|
||||
"all_false_field": false
|
||||
}
|
||||
|
@ -109,7 +125,7 @@ setup:
|
|||
indices.refresh: {}
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition auc_roc":
|
||||
"Test binary_soft_classification auc_roc":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
|
@ -129,7 +145,7 @@ setup:
|
|||
- is_false: binary_soft_classification.auc_roc.curve
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition auc_roc given actual_field is int":
|
||||
"Test binary_soft_classification auc_roc given actual_field is int":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
|
@ -149,7 +165,7 @@ setup:
|
|||
- is_false: binary_soft_classification.auc_roc.curve
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition auc_roc include curve":
|
||||
"Test binary_soft_classification auc_roc include curve":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
|
@ -169,7 +185,7 @@ setup:
|
|||
- is_true: binary_soft_classification.auc_roc.curve
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition auc_roc given actual_field is always true":
|
||||
"Test binary_soft_classification auc_roc given actual_field is always true":
|
||||
- do:
|
||||
catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/
|
||||
ml.evaluate_data_frame:
|
||||
|
@ -188,7 +204,7 @@ setup:
|
|||
}
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition auc_roc given actual_field is always false":
|
||||
"Test binary_soft_classification auc_roc given actual_field is always false":
|
||||
- do:
|
||||
catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/
|
||||
ml.evaluate_data_frame:
|
||||
|
@ -207,7 +223,7 @@ setup:
|
|||
}
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition precision":
|
||||
"Test binary_soft_classification precision":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
|
@ -230,7 +246,7 @@ setup:
|
|||
'0.5': 1.0
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition recall":
|
||||
"Test binary_soft_classification recall":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
|
@ -254,7 +270,7 @@ setup:
|
|||
'0.5': 0.6
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition confusion_matrix":
|
||||
"Test binary_soft_classification confusion_matrix":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
|
@ -290,7 +306,7 @@ setup:
|
|||
fn: 2
|
||||
|
||||
---
|
||||
"Test binary_soft_classifition default metrics":
|
||||
"Test binary_soft_classification default metrics":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
|
@ -356,7 +372,7 @@ setup:
|
|||
}
|
||||
|
||||
---
|
||||
"Test binary_soft_classification given evaluation with emtpy metrics":
|
||||
"Test binary_soft_classification given evaluation with empty metrics":
|
||||
- do:
|
||||
catch: /\[binary_soft_classification\] must have one or more metrics/
|
||||
ml.evaluate_data_frame:
|
||||
|
@ -518,3 +534,52 @@ setup:
|
|||
}
|
||||
}
|
||||
}
|
||||
---
|
||||
"Test regression given evaluation with empty metrics":
|
||||
- do:
|
||||
catch: /\[regression\] must have one or more metrics/
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"regression": {
|
||||
"actual_field": "regression_field_act",
|
||||
"predicted_field": "regression_field_pred",
|
||||
"metrics": { }
|
||||
}
|
||||
}
|
||||
}
|
||||
---
|
||||
"Test regression mean_squared_error":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"regression": {
|
||||
"actual_field": "regression_field_act",
|
||||
"predicted_field": "regression_field_pred",
|
||||
"metrics": { "mean_squared_error": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
---
|
||||
"Test regression with null metrics":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"regression": {
|
||||
"actual_field": "regression_field_act",
|
||||
"predicted_field": "regression_field_pred"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
|
|
Loading…
Reference in New Issue