[ML] Adds support for regression.mean_squared_error to eval API (#44140) (#44218)

* [ML] Adds support for regression.mean_squared_error to eval API

* addressing PR comments

* fixing tests
This commit is contained in:
Benjamin Trent 2019-07-11 09:22:52 -05:00 committed by GitHub
parent 1636701d69
commit c82d9c5b50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1069 additions and 20 deletions

View File

@ -18,6 +18,8 @@
*/ */
package org.elasticsearch.client.ml.dataframe.evaluation; package org.elasticsearch.client.ml.dataframe.evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -38,12 +40,15 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluations // Evaluations
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent), Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
// Evaluation metrics // Evaluation metrics
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
// Evaluation metrics results // Evaluation metrics results
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent), EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
@ -51,6 +56,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent), EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent)); EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
} }

View File

@ -0,0 +1,129 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
 * Client-side description of the mean squared error metric for regression evaluations.
 * The metric compares two numerical fields and has no settings of its own, so it is
 * rendered as an empty object.
 *
 * equation: mse = 1/n * Σ(y - y´)^2
 */
public class MeanSquaredErrorMetric implements EvaluationMetric {

    /** Name under which this metric and its result are registered and rendered. */
    public static final String NAME = "mean_squared_error";

    // There is nothing to configure, so parsing only has to produce a fresh instance.
    private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER =
        new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new);

    public MeanSquaredErrorMetric() {
    }

    /**
     * Builds a metric from its XContent representation.
     *
     * @param parser parser positioned on the metric object
     * @return the parsed metric
     */
    public static MeanSquaredErrorMetric fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    @Override
    public String getName() {
        return NAME;
    }

    /** Writes the metric as an empty object because it carries no parameters. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        return builder.startObject().endObject();
    }

    /** All instances of this exact class are interchangeable, so only the runtime class is compared. */
    @Override
    public boolean equals(Object o) {
        return this == o || (o != null && getClass() == o.getClass());
    }

    @Override
    public int hashCode() {
        // constant hash derived from the name: instances carry no state that could distinguish them
        return Objects.hashCode(NAME);
    }

    /**
     * Result of the mean squared error metric: a single error value.
     */
    public static class Result implements EvaluationMetric.Result {

        public static final ParseField ERROR = new ParseField("error");

        private static final ConstructingObjectParser<Result, Void> PARSER =
            new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));

        static {
            PARSER.declareDouble(constructorArg(), ERROR);
        }

        private final double error;

        public Result(double error) {
            this.error = error;
        }

        /**
         * Builds a result from its XContent representation.
         *
         * @param parser parser positioned on the result object
         * @return the parsed result
         */
        public static Result fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        /** @return the error value held by this result */
        public double getError() {
            return error;
        }

        @Override
        public String getMetricName() {
            return NAME;
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            return builder.startObject()
                .field(ERROR.getPreferredName(), error)
                .endObject();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Result other = (Result) o;
            // Double.compare follows the same bit-wise semantics as boxed Double equality (NaN == NaN, 0.0 != -0.0)
            return Double.compare(other.error, error) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(error);
        }
    }
}

View File

@ -0,0 +1,129 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
/**
 * Evaluation of regression results.
 * <p>
 * Compares a field holding the actual value with a field holding the predicted value
 * using the configured metrics. Both field names are mandatory; the list of metrics is optional.
 */
public class Regression implements Evaluation {

    public static final String NAME = "regression";

    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
    private static final ParseField METRICS = new ParseField("metrics");

    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
        NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));

    static {
        // both field names are required, the metrics object may be left out
        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
            (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
    }

    /**
     * Builds a regression evaluation from its XContent representation.
     *
     * @param parser parser positioned on the evaluation object
     * @return the parsed evaluation
     */
    public static Regression fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * The field containing the actual value
     * The value of this field is assumed to be numeric
     */
    private final String actualField;

    /**
     * The field containing the predicted value
     * The value of this field is assumed to be numeric
     */
    private final String predictedField;

    /**
     * The list of metrics to calculate
     */
    private final List<EvaluationMetric> metrics;

    /**
     * Creates an evaluation without an explicit list of metrics.
     *
     * @param actualField    name of the field containing the actual value, must not be {@code null}
     * @param predictedField name of the field containing the predicted value, must not be {@code null}
     * @throws NullPointerException if either field name is {@code null}
     */
    public Regression(String actualField, String predictedField) {
        // the cast selects the List overload; a bare null would be ambiguous with the varargs constructor
        this(actualField, predictedField, (List<EvaluationMetric>)null);
    }

    /**
     * Creates an evaluation for the given metrics.
     *
     * @param actualField    name of the field containing the actual value, must not be {@code null}
     * @param predictedField name of the field containing the predicted value, must not be {@code null}
     * @param metrics        the metrics to calculate
     * @throws NullPointerException if either field name is {@code null}
     */
    public Regression(String actualField, String predictedField, EvaluationMetric... metrics) {
        this(actualField, predictedField, Arrays.asList(metrics));
    }

    /**
     * Creates an evaluation for the given metrics.
     *
     * @param actualField    name of the field containing the actual value, must not be {@code null}
     * @param predictedField name of the field containing the predicted value, must not be {@code null}
     * @param metrics        the metrics to calculate, or {@code null} to leave the metrics object out of the request
     * @throws NullPointerException if either field name is {@code null}
     */
    public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
        // Fail fast: the parser treats both fields as required, so a null here could only
        // produce a request body that cannot be parsed back.
        this.actualField = Objects.requireNonNull(actualField, "[actual_field] must not be null");
        this.predictedField = Objects.requireNonNull(predictedField, "[predicted_field] must not be null");
        this.metrics = metrics;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Writes the two field names and, when metrics were supplied, a {@code metrics} object
     * holding one entry per metric keyed by the metric name.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);

        if (metrics != null) {
            builder.startObject(METRICS.getPreferredName());
            for (EvaluationMetric metric : metrics) {
                builder.field(metric.getName(), metric);
            }
            builder.endObject();
        }

        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Regression that = (Regression) o;
        return Objects.equals(that.actualField, this.actualField)
            && Objects.equals(that.predictedField, this.predictedField)
            && Objects.equals(that.metrics, this.metrics);
    }

    @Override
    public int hashCode() {
        return Objects.hash(actualField, predictedField, metrics);
    }
}

View File

@ -123,6 +123,8 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
@ -1578,6 +1580,33 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
String regressionIndex = "evaluate-regression-test-index";
createIndex(regressionIndex, mappingForRegression());
BulkRequest regressionBulk = new BulkRequest()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.add(docForRegression(regressionIndex, 0.3, 0.1)) // #0
.add(docForRegression(regressionIndex, 0.3, 0.2)) // #1
.add(docForRegression(regressionIndex, 0.3, 0.3)) // #2
.add(docForRegression(regressionIndex, 0.3, 0.4)) // #3
.add(docForRegression(regressionIndex, 0.3, 0.7)) // #4
.add(docForRegression(regressionIndex, 0.5, 0.2)) // #5
.add(docForRegression(regressionIndex, 0.5, 0.3)) // #6
.add(docForRegression(regressionIndex, 0.5, 0.4)) // #7
.add(docForRegression(regressionIndex, 0.5, 0.8)) // #8
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression));
evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
} }
private static XContentBuilder defaultMappingForTest() throws IOException { private static XContentBuilder defaultMappingForTest() throws IOException {
@ -1615,6 +1644,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p); .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
} }
// Field names shared by the regression test index mapping and the documents indexed into it.
// actualRegression is passed to Regression as the actual-value field.
private static final String actualRegression = "regression_actual";
// Despite its name, this field is passed to Regression as the predicted-value field.
private static final String probabilityRegression = "regression_prob";
/**
 * Builds the mapping of the regression test index: the actual and the predicted
 * field, both mapped as {@code double}.
 */
private static XContentBuilder mappingForRegression() throws IOException {
    XContentBuilder mapping = XContentFactory.jsonBuilder();
    mapping.startObject();
    mapping.startObject("properties");
    // both fields share the same definition, so emit them in declaration order from one loop
    for (String fieldName : new String[] {actualRegression, probabilityRegression}) {
        mapping.startObject(fieldName);
        mapping.field("type", "double");
        mapping.endObject();
    }
    mapping.endObject();
    mapping.endObject();
    return mapping;
}
/**
 * Creates an index request for one regression test document.
 *
 * @param indexName index the document is written to
 * @param act       value stored in the actual field
 * @param p         value stored in the predicted field
 */
private static IndexRequest docForRegression(String indexName, double act, double p) {
    IndexRequest request = new IndexRequest();
    request.index(indexName);
    request.source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
    return request;
}
private void createIndex(String indexName, XContentBuilder mapping) throws IOException { private void createIndex(String indexName, XContentBuilder mapping) throws IOException {
highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT); highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT);
} }

View File

@ -60,6 +60,8 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction;
import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.indexlifecycle.UnfollowAction;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
@ -674,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() { public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents(); List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(31, namedXContents.size()); assertEquals(34, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>(); Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>(); List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) { for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -712,12 +714,14 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
assertTrue(names.contains(TimeSyncConfig.NAME)); assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(BinarySoftClassification.NAME)); assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); assertThat(names,
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names,
hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
} }
public void testApiNamingConventions() throws Exception { public void testApiNamingConventions() throws Exception {

View File

@ -20,6 +20,7 @@ package org.elasticsearch.client.ml;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.AbstractXContentTestCase;
@ -45,6 +46,9 @@ public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<Eva
if (randomBoolean()) { if (randomBoolean()) {
metrics.add(ConfusionMatrixMetricResultTests.randomResult()); metrics.add(ConfusionMatrixMetricResultTests.randomResult());
} }
if (randomBoolean()) {
metrics.add(MeanSquaredErrorMetricResultTests.randomResult());
}
return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics); return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics);
} }

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class MeanSquaredErrorMetricResultTests extends AbstractXContentTestCase<MeanSquaredErrorMetric.Result> {
public static MeanSquaredErrorMetric.Result randomResult() {
return new MeanSquaredErrorMetric.Result(randomDouble());
}
@Override
protected MeanSquaredErrorMetric.Result createTestInstance() {
return randomResult();
}
@Override
protected MeanSquaredErrorMetric.Result doParseInstance(XContentParser parser) throws IOException {
return MeanSquaredErrorMetric.Result.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
}
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip tests for {@link MeanSquaredErrorMetric}.
 */
public class MeanSquaredErrorMetricTests extends AbstractXContentTestCase<MeanSquaredErrorMetric> {

    @Override
    protected MeanSquaredErrorMetric createTestInstance() {
        // the metric has no settings, so every instance is the same
        return new MeanSquaredErrorMetric();
    }

    @Override
    protected MeanSquaredErrorMetric doParseInstance(XContentParser parser) throws IOException {
        return MeanSquaredErrorMetric.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }
}

View File

@ -0,0 +1,59 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.Collections;
import java.util.function.Predicate;
/**
 * XContent round-trip tests for {@link Regression}.
 */
public class RegressionTests extends AbstractXContentTestCase<Regression> {

    /** Randomly creates an evaluation either without metrics or with the mean squared error metric. */
    @Override
    protected Regression createTestInstance() {
        // keep the order of random draws: the coin flip first, then the two field names
        boolean withoutMetrics = randomBoolean();
        String actualField = randomAlphaOfLength(10);
        String predictedField = randomAlphaOfLength(10);
        if (withoutMetrics) {
            return new Regression(actualField, predictedField);
        }
        return new Regression(actualField, predictedField, Collections.singletonList(new MeanSquaredErrorMetric()));
    }

    @Override
    protected Regression doParseInstance(XContentParser parser) throws IOException {
        return Regression.fromXContent(parser);
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        // allow unknown fields in the root of the object only
        return field -> field.isEmpty() == false;
    }
}

View File

@ -8,6 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
@ -28,6 +31,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluations // Evaluations
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
BinarySoftClassification::fromXContent)); BinarySoftClassification::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
// Soft classification metrics // Soft classification metrics
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent));
@ -36,6 +40,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
ConfusionMatrix::fromXContent)); ConfusionMatrix::fromXContent));
// Regression metrics
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
return namedXContent; return namedXContent;
} }
@ -45,6 +52,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
// Evaluations // Evaluations
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
BinarySoftClassification::new)); BinarySoftClassification::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
// Evaluation Metrics // Evaluation Metrics
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(),
@ -55,6 +63,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
Recall::new)); Recall::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix::new)); ConfusionMatrix::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError::new));
// Evaluation Metrics Results // Evaluation Metrics Results
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
@ -63,6 +74,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
ScoreByThresholdResult::new)); ScoreByThresholdResult::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
ConfusionMatrix.Result::new)); ConfusionMatrix.Result::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
MeanSquaredError.NAME.getPreferredName(),
MeanSquaredError.Result::new));
return namedWriteables; return namedWriteables;
} }

View File

@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
/**
 * Calculates the mean squared error between two known numerical fields.
 *
 * equation: mse = 1/n * Σ(y - y´)^2
 *
 * The metric itself carries no configuration: it serializes to an empty object and
 * writes nothing to the transport stream, so all instances are equal to each other.
 */
public class MeanSquaredError implements RegressionMetric {

    public static final ParseField NAME = new ParseField("mean_squared_error");

    // MessageFormat pattern: '' renders a single quote, {0} is the actual field and {1} the predicted field.
    private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
    private static final String AGG_NAME = "regression_" + NAME.getPreferredName();

    /**
     * Renders the Painless script that computes the squared difference of the two fields.
     * The field names are escaped first because they end up inside single-quoted
     * Painless string literals.
     */
    private static String buildScript(String actualField, String predictedField) {
        Object[] args = { escapeForSingleQuotedLiteral(actualField), escapeForSingleQuotedLiteral(predictedField) };
        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
    }

    /**
     * Escapes backslashes and single quotes so that {@code fieldName} cannot terminate
     * the single-quoted script literal it is embedded in.
     */
    private static String escapeForSingleQuotedLiteral(String fieldName) {
        // Backslashes must be doubled before quotes are escaped, otherwise the
        // backslash introduced for a quote would itself be doubled.
        return String.valueOf(fieldName).replace("\\", "\\\\").replace("'", "\\'");
    }

    private static final ObjectParser<MeanSquaredError, Void> PARSER =
        new ObjectParser<>(NAME.getPreferredName(), true, MeanSquaredError::new);

    /**
     * Parses the metric from its (empty) object representation.
     */
    public static MeanSquaredError fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Stream constructor; nothing is read because {@link #writeTo(StreamOutput)} writes nothing.
     */
    public MeanSquaredError(StreamInput in) {
    }

    public MeanSquaredError() {
    }

    @Override
    public String getMetricName() {
        return NAME.getPreferredName();
    }

    /**
     * Builds a single scripted {@code avg} aggregation over the squared difference of the two fields.
     *
     * @param actualField the field that stores the actual value
     * @param predictedField the field that stores the predicted value
     * @return a one-element list holding the aggregation
     */
    @Override
    public List<AggregationBuilder> aggs(String actualField, String predictedField) {
        Script squaredDifference = new Script(buildScript(actualField, predictedField));
        return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(squaredDifference));
    }

    /**
     * Reads the value of the aggregation created by {@link #aggs(String, String)}.
     *
     * @param aggs the aggregations of the search response
     * @return the result, or {@code null} when the aggregation is not present in {@code aggs}
     */
    @Override
    public EvaluationMetricResult evaluate(Aggregations aggs) {
        NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
        return value == null ? null : new Result(value.value());
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // No state to write.
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        // There are no fields, so two instances of the same class are always equal.
        return o != null && getClass() == o.getClass();
    }

    @Override
    public int hashCode() {
        // create static hash code from name as there are currently no unique fields per class instance
        return Objects.hashCode(NAME.getPreferredName());
    }

    /**
     * The computed mean squared error, rendered as {@code {"error": <value>}}.
     */
    public static class Result implements EvaluationMetricResult {

        private static final String ERROR = "error";

        private final double error;

        public Result(double error) {
            this.error = error;
        }

        public Result(StreamInput in) throws IOException {
            this.error = in.readDouble();
        }

        @Override
        public String getWriteableName() {
            return NAME.getPreferredName();
        }

        @Override
        public String getName() {
            return NAME.getPreferredName();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(error);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject();
            builder.field(ERROR, error);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Result other = (Result) o;
            // Double.compare treats NaN as equal to itself, keeping equals reflexive.
            return Double.compare(other.error, error) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(error);
        }
    }
}

View File

@ -0,0 +1,171 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
/**
 * Evaluation of regression results.
 *
 * Compares an actual field against a predicted field using a list of
 * {@link RegressionMetric}s. When no metrics are given, {@link MeanSquaredError} is used.
 */
public class Regression implements Evaluation {

    public static final ParseField NAME = new ParseField("regression");

    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
    private static final ParseField METRICS = new ParseField("metrics");

    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
        NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<RegressionMetric>) a[2]));

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
            (p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS);
    }

    public static Regression fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * The field containing the actual value
     * The value of this field is assumed to be numeric
     */
    private final String actualField;

    /**
     * The field containing the predicted value
     * The value of this field is assumed to be numeric
     */
    private final String predictedField;

    /**
     * The list of metrics to calculate
     */
    private final List<RegressionMetric> metrics;

    /**
     * @param actualField the field containing the actual value; checked with {@code ExceptionsHelper.requireNonNull}
     * @param predictedField the field containing the predicted value; checked with {@code ExceptionsHelper.requireNonNull}
     * @param metrics the metrics to calculate, or {@code null} for the default metrics; the list is
     *                copied, so the caller's list is never modified and may be unmodifiable
     */
    public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
        this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
        this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
        this.metrics = initMetrics(metrics);
    }

    public Regression(StreamInput in) throws IOException {
        this.actualField = in.readString();
        this.predictedField = in.readString();
        this.metrics = in.readNamedWriteableList(RegressionMetric.class);
    }

    /**
     * Resolves the metrics to use: the defaults when {@code parsedMetrics} is {@code null},
     * otherwise a copy of the given list. The result is sorted by metric name.
     * An empty list is rejected with a bad request exception.
     */
    private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
        // Work on a private copy: sorting the caller's list in place would mutate their
        // argument and fails outright for unmodifiable lists.
        List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
        if (metrics.isEmpty()) {
            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
        }
        Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName));
        return metrics;
    }

    private static List<RegressionMetric> defaultMetrics() {
        List<RegressionMetric> defaultMetrics = new ArrayList<>(1);
        defaultMetrics.add(new MeanSquaredError());
        return defaultMetrics;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Builds a hit-less search (size 0) restricted to documents where both fields exist,
     * carrying the aggregations of every metric.
     */
    @Override
    public SearchSourceBuilder buildSearch() {
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
            .filter(QueryBuilders.existsQuery(actualField))
            .filter(QueryBuilders.existsQuery(predictedField));
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
        for (RegressionMetric metric : metrics) {
            List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
            aggs.forEach(searchSourceBuilder::aggregation);
        }
        return searchSourceBuilder;
    }

    /**
     * Evaluates every metric against the aggregations of {@code searchResponse} and
     * passes the results to {@code listener}.
     */
    @Override
    public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
        List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
        for (RegressionMetric metric : metrics) {
            EvaluationMetricResult result = metric.evaluate(searchResponse.getAggregations());
            // A metric may return null when its aggregation is missing (see MeanSquaredError#evaluate);
            // leave such results out so the listener never receives null elements.
            if (result != null) {
                results.add(result);
            }
        }
        listener.onResponse(results);
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(actualField);
        out.writeString(predictedField);
        out.writeNamedWriteableList(metrics);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);

        // Metrics are rendered as an object keyed by each metric's writeable name.
        builder.startObject(METRICS.getPreferredName());
        for (RegressionMetric metric : metrics) {
            builder.field(metric.getWriteableName(), metric);
        }
        builder.endObject();

        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Regression that = (Regression) o;
        return Objects.equals(that.actualField, this.actualField)
            && Objects.equals(that.predictedField, this.predictedField)
            && Objects.equals(that.metrics, this.metrics);
    }

    @Override
    public int hashCode() {
        return Objects.hash(actualField, predictedField, metrics);
    }
}

View File

@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.util.List;
/**
 * A metric that can be computed by a regression evaluation from an actual
 * and a predicted field.
 */
public interface RegressionMetric extends ToXContentObject, NamedWriteable {

    /**
     * The name of this metric; it is not required to match the writeable name.
     *
     * @return the metric name
     */
    String getMetricName();

    /**
     * Creates the aggregations that gather the data this metric is computed from.
     *
     * @param actualField the field that stores the actual value
     * @param predictedField the field that stores the predicted value
     * @return the aggregations required to compute the metric
     */
    List<AggregationBuilder> aggs(String actualField, String predictedField);

    /**
     * Computes the metric result from the given aggregations.
     *
     * @param aggs the aggregations
     * @return the metric result
     */
    EvaluationMetricResult evaluate(Aggregations aggs);
}

View File

@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
@Override
protected MeanSquaredError doParseInstance(XContentParser parser) throws IOException {
return MeanSquaredError.fromXContent(parser);
}
@Override
protected MeanSquaredError createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<MeanSquaredError> instanceReader() {
return MeanSquaredError::new;
}
public static MeanSquaredError createRandom() {
return new MeanSquaredError();
}
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("regression_mean_squared_error", 0.8123),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
MeanSquaredError mse = new MeanSquaredError();
EvaluationMetricResult result = mse.evaluate(aggs);
String expected = "{\"error\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected));
}
public void testEvaluate_GivenMissingAggs() {
Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
MeanSquaredError mse = new MeanSquaredError();
EvaluationMetricResult result = mse.evaluate(aggs);
assertThat(result, is(nullValue()));
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
when(agg.getName()).thenReturn(name);
when(agg.value()).thenReturn(value);
return agg;
}
}

View File

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
/**
 * Serialization and construction tests for {@link Regression}.
 */
public class RegressionTests extends AbstractSerializingTestCase<Regression> {

    /**
     * Creates a regression evaluation with random field names and either
     * an explicit metric list or {@code null} metrics.
     */
    public static Regression createRandom() {
        List<RegressionMetric> explicitMetrics = Collections.singletonList(MeanSquaredErrorTests.createRandom());
        String actualField = randomAlphaOfLength(10);
        String predictedField = randomAlphaOfLength(10);
        List<RegressionMetric> chosenMetrics = randomBoolean() ? null : explicitMetrics;
        return new Regression(actualField, predictedField, chosenMetrics);
    }

    @Override
    protected Regression createTestInstance() {
        return createRandom();
    }

    @Override
    protected Regression doParseInstance(XContentParser parser) throws IOException {
        return Regression.fromXContent(parser);
    }

    @Override
    protected Writeable.Reader<Regression> instanceReader() {
        return Regression::new;
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedWriteableRegistry(provider.getNamedWriteables());
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }

    /** An explicitly empty metric list is rejected. */
    public void testConstructor_GivenEmptyMetrics() {
        ElasticsearchStatusException thrown = expectThrows(
            ElasticsearchStatusException.class,
            () -> new Regression("foo", "bar", Collections.emptyList()));

        assertThat(thrown.getMessage(), equalTo("[regression] must have one or more metrics"));
    }
}

View File

@ -72,9 +72,9 @@ integTest.runner {
'ml/evaluate_data_frame/Test given missing index', 'ml/evaluate_data_frame/Test given missing index',
'ml/evaluate_data_frame/Test given index does not exist', 'ml/evaluate_data_frame/Test given index does not exist',
'ml/evaluate_data_frame/Test given missing evaluation', 'ml/evaluate_data_frame/Test given missing evaluation',
'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always true', 'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always true',
'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always false', 'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always false',
'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with emtpy metrics', 'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with empty metrics',
'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field', 'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field',
'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field', 'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field',
'ml/evaluate_data_frame/Test binary_soft_classification given precision with threshold less than zero', 'ml/evaluate_data_frame/Test binary_soft_classification given precision with threshold less than zero',
@ -83,6 +83,7 @@ integTest.runner {
'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds',
'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
'ml/delete_job_force/Test cannot force delete a non-existent job', 'ml/delete_job_force/Test cannot force delete a non-existent job',
'ml/delete_model_snapshot/Test delete snapshot missing snapshotId', 'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
'ml/delete_model_snapshot/Test delete snapshot missing job_id', 'ml/delete_model_snapshot/Test delete snapshot missing job_id',

View File

@ -8,6 +8,8 @@ setup:
"is_outlier": false, "is_outlier": false,
"is_outlier_int": 0, "is_outlier_int": 0,
"outlier_score": 0.0, "outlier_score": 0.0,
"regression_field_act": 10.9,
"regression_field_pred": 10.9,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -20,6 +22,8 @@ setup:
"is_outlier": false, "is_outlier": false,
"is_outlier_int": 0, "is_outlier_int": 0,
"outlier_score": 0.2, "outlier_score": 0.2,
"regression_field_act": 12.0,
"regression_field_pred": 9.9,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -32,6 +36,8 @@ setup:
"is_outlier": false, "is_outlier": false,
"is_outlier_int": 0, "is_outlier_int": 0,
"outlier_score": 0.3, "outlier_score": 0.3,
"regression_field_act": 20.9,
"regression_field_pred": 5.9,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -44,6 +50,8 @@ setup:
"is_outlier": true, "is_outlier": true,
"is_outlier_int": 1, "is_outlier_int": 1,
"outlier_score": 0.3, "outlier_score": 0.3,
"regression_field_act": 11.9,
"regression_field_pred": 11.9,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -56,6 +64,8 @@ setup:
"is_outlier": true, "is_outlier": true,
"is_outlier_int": 1, "is_outlier_int": 1,
"outlier_score": 0.4, "outlier_score": 0.4,
"regression_field_act": 42.9,
"regression_field_pred": 42.9,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -68,6 +78,8 @@ setup:
"is_outlier": true, "is_outlier": true,
"is_outlier_int": 1, "is_outlier_int": 1,
"outlier_score": 0.5, "outlier_score": 0.5,
"regression_field_act": 0.42,
"regression_field_pred": 0.42,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -80,6 +92,8 @@ setup:
"is_outlier": true, "is_outlier": true,
"is_outlier_int": 1, "is_outlier_int": 1,
"outlier_score": 0.9, "outlier_score": 0.9,
"regression_field_act": 1.1235813,
"regression_field_pred": 1.12358,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -92,6 +106,8 @@ setup:
"is_outlier": true, "is_outlier": true,
"is_outlier_int": 1, "is_outlier_int": 1,
"outlier_score": 0.95, "outlier_score": 0.95,
"regression_field_act": -5.20,
"regression_field_pred": -5.1,
"all_true_field": true, "all_true_field": true,
"all_false_field": false "all_false_field": false
} }
@ -109,7 +125,7 @@ setup:
indices.refresh: {} indices.refresh: {}
--- ---
"Test binary_soft_classifition auc_roc": "Test binary_soft_classification auc_roc":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -129,7 +145,7 @@ setup:
- is_false: binary_soft_classification.auc_roc.curve - is_false: binary_soft_classification.auc_roc.curve
--- ---
"Test binary_soft_classifition auc_roc given actual_field is int": "Test binary_soft_classification auc_roc given actual_field is int":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -149,7 +165,7 @@ setup:
- is_false: binary_soft_classification.auc_roc.curve - is_false: binary_soft_classification.auc_roc.curve
--- ---
"Test binary_soft_classifition auc_roc include curve": "Test binary_soft_classification auc_roc include curve":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -169,7 +185,7 @@ setup:
- is_true: binary_soft_classification.auc_roc.curve - is_true: binary_soft_classification.auc_roc.curve
--- ---
"Test binary_soft_classifition auc_roc given actual_field is always true": "Test binary_soft_classification auc_roc given actual_field is always true":
- do: - do:
catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/ catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/
ml.evaluate_data_frame: ml.evaluate_data_frame:
@ -188,7 +204,7 @@ setup:
} }
--- ---
"Test binary_soft_classifition auc_roc given actual_field is always false": "Test binary_soft_classification auc_roc given actual_field is always false":
- do: - do:
catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/ catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/
ml.evaluate_data_frame: ml.evaluate_data_frame:
@ -207,7 +223,7 @@ setup:
} }
--- ---
"Test binary_soft_classifition precision": "Test binary_soft_classification precision":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -230,7 +246,7 @@ setup:
'0.5': 1.0 '0.5': 1.0
--- ---
"Test binary_soft_classifition recall": "Test binary_soft_classification recall":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -254,7 +270,7 @@ setup:
'0.5': 0.6 '0.5': 0.6
--- ---
"Test binary_soft_classifition confusion_matrix": "Test binary_soft_classification confusion_matrix":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -290,7 +306,7 @@ setup:
fn: 2 fn: 2
--- ---
"Test binary_soft_classifition default metrics": "Test binary_soft_classification default metrics":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -356,7 +372,7 @@ setup:
} }
--- ---
"Test binary_soft_classification given evaluation with emtpy metrics": "Test binary_soft_classification given evaluation with empty metrics":
- do: - do:
catch: /\[binary_soft_classification\] must have one or more metrics/ catch: /\[binary_soft_classification\] must have one or more metrics/
ml.evaluate_data_frame: ml.evaluate_data_frame:
@ -518,3 +534,52 @@ setup:
} }
} }
} }
---
"Test regression given evaluation with empty metrics":
- do:
catch: /\[regression\] must have one or more metrics/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred",
"metrics": { }
}
}
}
---
"Test regression mean_squared_error":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred",
"metrics": { "mean_squared_error": {} }
}
}
}
- match: { regression.mean_squared_error.error: 28.67749840974834 }
---
"Test regression with null metrics":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred"
}
}
}
- match: { regression.mean_squared_error.error: 28.67749840974834 }