diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 764ff41de86..b6f07fd4949 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -18,6 +18,8 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -38,12 +40,15 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluations new NamedXContentRegistry.Entry( Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent), + new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent), // Evaluation metrics new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, new 
ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent), // Evaluation metrics results new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent), @@ -51,6 +56,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent)); } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java new file mode 100644 index 00000000000..5b961dacbcc --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java @@ -0,0 +1,129 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Calculates the mean squared error between two known numerical fields. 
+ *
+ * equation: mse = 1/n * Σ(y - y´)^2
+ *
+ * The metric has no configuration of its own, so every instance is equal to every other instance.
+ */
+public class MeanSquaredErrorMetric implements EvaluationMetric {
+
+    /** Name under which this metric is written and registered. */
+    public static final String NAME = "mean_squared_error";
+
+    // The metric body is an empty object; the parser only has to create an instance.
+    private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER =
+        new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new);
+
+    /**
+     * Parses a metric from the given parser.
+     *
+     * @param parser parser positioned on the metric object
+     * @return the parsed metric
+     */
+    public static MeanSquaredErrorMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public MeanSquaredErrorMetric() {
+
+    }
+
+    /** Writes the metric as an empty object. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        // create static hash code from name as there are currently no unique fields per class instance
+        return Objects.hashCode(NAME);
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    /**
+     * Result of the mean squared error metric: a single {@code error} value.
+     */
+    public static class Result implements EvaluationMetric.Result {
+
+        public static final ParseField ERROR = new ParseField("error");
+
+        // "error" is a required constructor argument of the result.
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
+
+        static {
+            PARSER.declareDouble(constructorArg(), ERROR);
+        }
+
+        private final double error;
+
+        /**
+         * Parses a result from the given parser.
+         *
+         * @param parser parser positioned on the result object
+         * @return the parsed result
+         */
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        public Result(double error) {
+            this.error = error;
+        }
+
+        /** Writes the result as an object with the single field {@code error}. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+            builder.startObject();
+            builder.field(ERROR.getPreferredName(), error);
+            builder.endObject();
+            return builder;
+        }
+
+        public double getError() {
+            return error;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            // Double.compare gives the same answer as boxed Double.equals without allocating wrappers.
+            return Double.compare(that.error, this.error) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(error);
+        }
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java
new file mode 100644
index 00000000000..13b14f6e0b0
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Evaluation of regression results.
+ *
+ * Holds the names of the actual and predicted fields and an optional list of metrics to calculate.
+ */
+public class Regression implements Evaluation {
+
+    public static final String NAME = "regression";
+
+    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
+    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField METRICS = new ParseField("metrics");
+
+    // The unchecked cast is confined to a[2]: it is populated only by the METRICS declaration below,
+    // whose named-object parser is asked for EvaluationMetric instances.
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
+        NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
+    }
+
+    /**
+     * Parses a regression evaluation from the given parser.
+     *
+     * @param parser parser positioned on the evaluation object
+     * @return the parsed evaluation
+     */
+    public static Regression fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * The field containing the actual value
+     * The value of this field is assumed to be numeric
+     */
+    private final String actualField;
+
+    /**
+     * The field containing the predicted value
+     * The value of this field is assumed to be numeric
+     */
+    private final String predictedField;
+
+    /**
+     * The list of metrics to calculate
+     */
+    private final List<EvaluationMetric> metrics;
+
+    /**
+     * Creates an evaluation without an explicit metric list ({@code metrics} is left {@code null}).
+     */
+    public Regression(String actualField, String predictedField) {
+        this(actualField, predictedField, (List<EvaluationMetric>) null);
+    }
+
+    /**
+     * Creates an evaluation for the given metrics.
+     */
+    public Regression(String actualField, String predictedField, EvaluationMetric... metrics) {
+        this(actualField, predictedField, Arrays.asList(metrics));
+    }
+
+    /**
+     * @param actualField    name of the field containing the actual value
+     * @param predictedField name of the field containing the predicted value
+     * @param metrics        metrics to calculate, or {@code null} to omit the {@code metrics} object when serializing
+     */
+    public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
+        this.actualField = actualField;
+        this.predictedField = predictedField;
+        this.metrics = metrics;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    /**
+     * Writes both field names and, when a metric list is present, a {@code metrics} object
+     * with one entry per metric keyed by the metric's name.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
+        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+
+        if (metrics != null) {
+            builder.startObject(METRICS.getPreferredName());
+            for (EvaluationMetric metric : metrics) {
+                builder.field(metric.getName(), metric);
+            }
+            builder.endObject();
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Regression that = (Regression) o;
+        return Objects.equals(that.actualField, this.actualField)
+            && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.metrics, this.metrics);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(actualField, predictedField, metrics);
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
index b542db9c4b0..d99d9ecd29d 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
@@ -123,6
+123,8 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; @@ -1578,6 +1580,33 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); + + String regressionIndex = "evaluate-regression-test-index"; + createIndex(regressionIndex, mappingForRegression()); + BulkRequest regressionBulk = new BulkRequest() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(docForRegression(regressionIndex, 0.3, 0.1)) // #0 + .add(docForRegression(regressionIndex, 0.3, 0.2)) // #1 + .add(docForRegression(regressionIndex, 0.3, 0.3)) // #2 + .add(docForRegression(regressionIndex, 0.3, 0.4)) // #3 + .add(docForRegression(regressionIndex, 0.3, 0.7)) // #4 + .add(docForRegression(regressionIndex, 0.5, 0.2)) // #5 + .add(docForRegression(regressionIndex, 0.5, 0.3)) // #6 + .add(docForRegression(regressionIndex, 0.5, 0.4)) // #7 + .add(docForRegression(regressionIndex, 0.5, 0.8)) // #8 + .add(docForRegression(regressionIndex, 0.5, 0.9)); // #9 + highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); + + evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new 
Regression(actualRegression, probabilityRegression)); + + evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + + MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME); + assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME)); + assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9)); } private static XContentBuilder defaultMappingForTest() throws IOException { @@ -1615,6 +1644,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p); } + private static final String actualRegression = "regression_actual"; + private static final String probabilityRegression = "regression_prob"; + + private static XContentBuilder mappingForRegression() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(actualRegression) + .field("type", "double") + .endObject() + .startObject(probabilityRegression) + .field("type", "double") + .endObject() + .endObject() + .endObject(); + } + + private static IndexRequest docForRegression(String indexName, double act, double p) { + return new IndexRequest() + .index(indexName) + .source(XContentType.JSON, actualRegression, act, probabilityRegression, p); + } + private void createIndex(String indexName, XContentBuilder mapping) throws IOException { highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index ae1cd5eb45e..77dc9ee53fd 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -60,6 +60,8 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction; import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; @@ -674,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(31, namedXContents.size()); + assertEquals(34, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -712,12 +714,14 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); - assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); - assertThat(names, hasItems(BinarySoftClassification.NAME)); - assertEquals(Integer.valueOf(4), 
categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); - assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); - assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); - assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); + assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); + assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME)); + assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertThat(names, + hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME)); + assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertThat(names, + hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java index b41d113686c..70740a3268f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.client.ml; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import 
org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; @@ -45,6 +46,9 @@ public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase { + + public static MeanSquaredErrorMetric.Result randomResult() { + return new MeanSquaredErrorMetric.Result(randomDouble()); + } + + @Override + protected MeanSquaredErrorMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected MeanSquaredErrorMetric.Result doParseInstance(XContentParser parser) throws IOException { + return MeanSquaredErrorMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java new file mode 100644 index 00000000000..9027462b21e --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+/**
+ * Exercises XContent parsing of {@link MeanSquaredErrorMetric} through {@link AbstractXContentTestCase}.
+ */
+public class MeanSquaredErrorMetricTests extends AbstractXContentTestCase<MeanSquaredErrorMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected MeanSquaredErrorMetric createTestInstance() {
+        return new MeanSquaredErrorMetric();
+    }
+
+    @Override
+    protected MeanSquaredErrorMetric doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredErrorMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java
new file mode 100644
index 00000000000..f5b3db9cec8
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.function.Predicate;
+
+/**
+ * Exercises XContent parsing of {@link Regression} through {@link AbstractXContentTestCase},
+ * both without a metric list and with a single {@link MeanSquaredErrorMetric}.
+ */
+public class RegressionTests extends AbstractXContentTestCase<Regression> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected Regression createTestInstance() {
+        return randomBoolean() ?
+            new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10)) :
+            new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredErrorMetric()));
+    }
+
+    @Override
+    protected Regression doParseInstance(XContentParser parser) throws IOException {
+        return Regression.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in the root of the object only
+        return field -> !field.isEmpty();
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
index f4a6dba88e3..f713aa0033d 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
@@ -8,6 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.plugins.spi.NamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
@@ -28,6 +31,7 @@ public
class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluations namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent)); // Soft classification metrics namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent)); @@ -36,6 +40,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, ConfusionMatrix::fromXContent)); + // Regression metrics + namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent)); + return namedXContent; } @@ -45,6 +52,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Evaluations namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new)); // Evaluation Metrics namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), @@ -55,6 +63,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider Recall::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, + MeanSquaredError.NAME.getPreferredName(), + MeanSquaredError::new)); // Evaluation Metrics Results namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, 
AucRoc.NAME.getPreferredName(), @@ -63,6 +74,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider ScoreByThresholdResult::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), ConfusionMatrix.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + MeanSquaredError.NAME.getPreferredName(), + MeanSquaredError.Result::new)); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java new file mode 100644 index 00000000000..8dd922b6ac2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Objects; + +/** + * Calculates the mean squared error between two known numerical fields. 
+ *
+ * equation: mse = 1/n * Σ(y - y´)^2
+ *
+ * The value is computed as an {@code avg} aggregation over a script that squares the
+ * difference between the actual and the predicted field of each document.
+ */
+public class MeanSquaredError implements RegressionMetric {
+
+    public static final ParseField NAME = new ParseField("mean_squared_error");
+
+    // MessageFormat pattern: '' is an escaped single quote, {0} is the actual field, {1} the predicted field.
+    // NOTE(review): field names are inserted verbatim between the quotes; a name containing a quote
+    // would yield a malformed script — confirm callers validate field names.
+    private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
+    private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
+
+    /** Fills {@link #PAINLESS_TEMPLATE} with the given arguments. */
+    private static String buildScript(Object...args) {
+        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
+    }
+
+    // The metric body is an empty object; the parser only has to create an instance.
+    private static final ObjectParser<MeanSquaredError, Void> PARSER =
+        new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new);
+
+    /**
+     * Parses a metric from the given parser.
+     *
+     * @param parser parser positioned on the metric object
+     * @return the parsed metric
+     */
+    public static MeanSquaredError fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /** Wire constructor; the metric has no state, so nothing is read from the stream. */
+    public MeanSquaredError(StreamInput in) {
+
+    }
+
+    public MeanSquaredError() {
+
+    }
+
+    @Override
+    public String getMetricName() {
+        return NAME.getPreferredName();
+    }
+
+    /**
+     * Builds the single aggregation needed by this metric.
+     *
+     * @param actualField    name of the field holding the actual value
+     * @param predictedField name of the field holding the predicted value
+     * @return a one-element list with an {@code avg} aggregation over the squared-difference script
+     */
+    @Override
+    public List<AggregationBuilder> aggs(String actualField, String predictedField) {
+        return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
+    }
+
+    /**
+     * Extracts the metric result from the search aggregations.
+     *
+     * @param aggs aggregations returned by the search
+     * @return the result, or {@code null} when the aggregation built by {@link #aggs} is absent
+     */
+    @Override
+    public EvaluationMetricResult evaluate(Aggregations aggs) {
+        NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
+        return value == null ? null : new Result(value.value());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    /** Nothing to write: the metric has no state. */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+
+    }
+
+    /** Writes the metric as an empty object. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        // create static hash code from name as there are currently no unique fields per class instance
+        return Objects.hashCode(NAME.getPreferredName());
+    }
+
+    /**
+     * Result of the mean squared error metric: a single {@code error} value.
+     */
+    public static class Result implements EvaluationMetricResult {
+
+        private static final String ERROR = "error";
+        private final double error;
+
+        public Result(double error) {
+            this.error = error;
+        }
+
+        /** Wire constructor; reads the error value written by {@link #writeTo}. */
+        public Result(StreamInput in) throws IOException {
+            this.error = in.readDouble();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public String getName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeDouble(error);
+        }
+
+        /** Writes the result as an object with the single field {@code error}. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ERROR, error);
+            builder.endObject();
+            return builder;
+        }
+
+        // Value semantics, so that two results carrying the same error compare equal
+        // (for example an instance and its copy read back from a stream).
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result other = (Result) o;
+            return Double.compare(error, other.error) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Double.hashCode(error);
+        }
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
new file mode 100644
index 00000000000..455f44ae3c1
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
@@ -0,0 +1,171 @@
+/*
+ * Copyright
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of regression results. 
+ *
+ * Compares a field holding actual values with a field holding predicted values using a list of
+ * {@link RegressionMetric}s. When no metrics are supplied, {@link MeanSquaredError} is used.
+ */
+public class Regression implements Evaluation {
+
+    public static final ParseField NAME = new ParseField("regression");
+
+    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
+    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField METRICS = new ParseField("metrics");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
+        NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<RegressionMetric>) a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS);
+    }
+
+    /**
+     * Parses an evaluation from its XContent form.
+     * @param parser the parser positioned at the evaluation object
+     * @return the parsed evaluation
+     */
+    public static Regression fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * The field containing the actual value
+     * The value of this field is assumed to be numeric
+     */
+    private final String actualField;
+
+    /**
+     * The field containing the predicted value
+     * The value of this field is assumed to be numeric
+     */
+    private final String predictedField;
+
+    /**
+     * The list of metrics to calculate
+     */
+    private final List<RegressionMetric> metrics;
+
+    /**
+     * @param actualField the field containing the actual value; must not be null
+     * @param predictedField the field containing the predicted value; must not be null
+     * @param metrics the metrics to calculate; {@code null} selects the default metrics,
+     *                an empty list is rejected with a bad-request exception
+     */
+    public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
+        this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
+        this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
+        this.metrics = initMetrics(metrics);
+    }
+
+    public Regression(StreamInput in) throws IOException {
+        this.actualField = in.readString();
+        this.predictedField = in.readString();
+        this.metrics = in.readNamedWriteableList(RegressionMetric.class);
+    }
+
+    /**
+     * Resolves the metrics to use: the defaults when none were given, otherwise a copy of the
+     * given list, in both cases ordered by metric name.
+     */
+    private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
+        // Copy before sorting: sorting the caller's list in place would mutate the argument and
+        // fails on unmodifiable lists.
+        List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
+        if (metrics.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
+        }
+        Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName));
+        return metrics;
+    }
+
+    private static List<RegressionMetric> defaultMetrics() {
+        List<RegressionMetric> defaultMetrics = new ArrayList<>(1);
+        defaultMetrics.add(new MeanSquaredError());
+        return defaultMetrics;
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    /**
+     * Builds a hits-free search restricted to documents where both fields exist, carrying the
+     * aggregations of every configured metric.
+     */
+    @Override
+    public SearchSourceBuilder buildSearch() {
+        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
+            .filter(QueryBuilders.existsQuery(actualField))
+            .filter(QueryBuilders.existsQuery(predictedField));
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
+        for (RegressionMetric metric : metrics) {
+            List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
+            aggs.forEach(searchSourceBuilder::aggregation);
+        }
+        return searchSourceBuilder;
+    }
+
+    /**
+     * Computes each metric from the search response aggregations and passes the results to the listener.
+     */
+    @Override
+    public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
+        List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
+        for (RegressionMetric metric : metrics) {
+            EvaluationMetricResult result = metric.evaluate(searchResponse.getAggregations());
+            // A metric may return null when its aggregation is absent from the response;
+            // skip it rather than hand a null element to the listener.
+            if (result != null) {
+                results.add(result);
+            }
+        }
+        listener.onResponse(results);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(actualField);
+        out.writeString(predictedField);
+        out.writeNamedWriteableList(metrics);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
+        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+
+        // metrics are rendered as an object keyed by each metric's writeable name
+        builder.startObject(METRICS.getPreferredName());
+        for (RegressionMetric metric : metrics) {
+            builder.field(metric.getWriteableName(), metric);
+        }
+        builder.endObject();
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Regression that = (Regression) o;
+        return Objects.equals(that.actualField, this.actualField)
+            && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.metrics, this.metrics);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(actualField, predictedField, metrics);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java new file mode 100644 index 00000000000..1da48e2f305 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.util.List;
+
+/**
+ * A metric that can be computed for a regression evaluation: it contributes aggregations to the
+ * evaluation search and turns the aggregation results into a metric result.
+ */
+public interface RegressionMetric extends ToXContentObject, NamedWriteable {
+
+    /**
+     * Returns the name of the metric (which may differ from the writeable name)
+     */
+    String getMetricName();
+
+    /**
+     * Builds the aggregations that collect the data required to compute the metric
+     * @param actualField the field that stores the actual value
+     * @param predictedField the field that stores the predicted value
+     * @return the aggregations required to compute the metric
+     */
+    List<AggregationBuilder> aggs(String actualField, String predictedField);
+
+    /**
+     * Calculates the metric result
+     * @param aggs the aggregations
+     * @return the metric result; implementations may return {@code null} when the aggregations
+     *         they need are missing (as {@code MeanSquaredError} does)
+     */
+    EvaluationMetricResult evaluate(Aggregations aggs);
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java new file mode 100644 index 00000000000..43513514747 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Serialization round-trip tests for {@link MeanSquaredError} plus checks of how it reads its
+ * aggregation out of a search response.
+ */
+public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
+
+    @Override
+    protected MeanSquaredError doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredError.fromXContent(parser);
+    }
+
+    @Override
+    protected MeanSquaredError createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<MeanSquaredError> instanceReader() {
+        return MeanSquaredError::new;
+    }
+
+    // The metric has no fields, so there is nothing to randomize.
+    public static MeanSquaredError createRandom() {
+        return new MeanSquaredError();
+    }
+
+    // The metric must pick its own aggregation by name and ignore unrelated ones.
+    public void testEvaluate() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            createSingleMetricAgg("regression_mean_squared_error", 0.8123),
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+        ));
+
+        MeanSquaredError mse = new MeanSquaredError();
+        EvaluationMetricResult result = mse.evaluate(aggs);
+
+        String expected = "{\"error\":0.8123}";
+        assertThat(Strings.toString(result), equalTo(expected));
+    }
+
+    // When the metric's aggregation is absent, evaluate returns null.
+    public void testEvaluate_GivenMissingAggs() {
+        Aggregations aggs = new Aggregations(Collections.singletonList(
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+        ));
+
+        MeanSquaredError mse = new MeanSquaredError();
+        EvaluationMetricResult result = mse.evaluate(aggs);
+        assertThat(result, is(nullValue()));
+    }
+
+    /** Creates a mocked single-value aggregation with the given name and value. */
+    private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
+        NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
+        when(agg.getName()).thenReturn(name);
+        when(agg.value()).thenReturn(value);
+        return agg;
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java new file mode 100644 index 00000000000..33ce6e56ff5 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * Serialization round-trip tests for {@link Regression} plus validation of its constructor.
+ */
+public class RegressionTests extends AbstractSerializingTestCase<Regression> {
+
+    // Metrics are named writeables / named XContent objects, so both registries must know them.
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    // Randomly passes null metrics so the default-metrics path is exercised as well.
+    public static Regression createRandom() {
+        List<RegressionMetric> metrics = Collections.singletonList(MeanSquaredErrorTests.createRandom());
+        return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), randomBoolean() ? null : metrics);
+    }
+
+    @Override
+    protected Regression doParseInstance(XContentParser parser) throws IOException {
+        return Regression.fromXContent(parser);
+    }
+
+    @Override
+    protected Regression createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Regression> instanceReader() {
+        return Regression::new;
+    }
+
+    // An explicitly empty metrics list is rejected, unlike null which selects the defaults.
+    public void testConstructor_GivenEmptyMetrics() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", "bar", Collections.emptyList()));
+        assertThat(e.getMessage(), equalTo("[regression] must have one or more metrics"));
+    }
+}
diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 9335783db14..b5efa6736f3 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -72,9 +72,9 @@ integTest.runner { 'ml/evaluate_data_frame/Test given missing index', 'ml/evaluate_data_frame/Test given index does not exist', 'ml/evaluate_data_frame/Test given missing evaluation', - 'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always true', - 'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always false', - 'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with emtpy metrics', + 'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always true', + 'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always false', + 'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with empty metrics', 'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field', 'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field', 'ml/evaluate_data_frame/Test binary_soft_classification given precision with threshold less than zero', @@ 
-83,6 +83,7 @@ integTest.runner { 'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds', + 'ml/evaluate_data_frame/Test regression given evaluation with empty metrics', 'ml/delete_job_force/Test cannot force delete a non-existent job', 'ml/delete_model_snapshot/Test delete snapshot missing snapshotId', 'ml/delete_model_snapshot/Test delete snapshot missing job_id', diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index ef844d61f16..d0ed46b0f04 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -8,6 +8,8 @@ setup: "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.0, + "regression_field_act": 10.9, + "regression_field_pred": 10.9, "all_true_field": true, "all_false_field": false } @@ -20,6 +22,8 @@ setup: "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.2, + "regression_field_act": 12.0, + "regression_field_pred": 9.9, "all_true_field": true, "all_false_field": false } @@ -32,6 +36,8 @@ setup: "is_outlier": false, "is_outlier_int": 0, "outlier_score": 0.3, + "regression_field_act": 20.9, + "regression_field_pred": 5.9, "all_true_field": true, "all_false_field": false } @@ -44,6 +50,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.3, + "regression_field_act": 11.9, + "regression_field_pred": 11.9, "all_true_field": true, "all_false_field": false } @@ -56,6 +64,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.4, + "regression_field_act": 42.9, + "regression_field_pred": 42.9, "all_true_field": true, "all_false_field": 
false } @@ -68,6 +78,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.5, + "regression_field_act": 0.42, + "regression_field_pred": 0.42, "all_true_field": true, "all_false_field": false } @@ -80,6 +92,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.9, + "regression_field_act": 1.1235813, + "regression_field_pred": 1.12358, "all_true_field": true, "all_false_field": false } @@ -92,6 +106,8 @@ setup: "is_outlier": true, "is_outlier_int": 1, "outlier_score": 0.95, + "regression_field_act": -5.20, + "regression_field_pred": -5.1, "all_true_field": true, "all_false_field": false } @@ -109,7 +125,7 @@ setup: indices.refresh: {} --- -"Test binary_soft_classifition auc_roc": +"Test binary_soft_classification auc_roc": - do: ml.evaluate_data_frame: body: > @@ -129,7 +145,7 @@ setup: - is_false: binary_soft_classification.auc_roc.curve --- -"Test binary_soft_classifition auc_roc given actual_field is int": +"Test binary_soft_classification auc_roc given actual_field is int": - do: ml.evaluate_data_frame: body: > @@ -149,7 +165,7 @@ setup: - is_false: binary_soft_classification.auc_roc.curve --- -"Test binary_soft_classifition auc_roc include curve": +"Test binary_soft_classification auc_roc include curve": - do: ml.evaluate_data_frame: body: > @@ -169,7 +185,7 @@ setup: - is_true: binary_soft_classification.auc_roc.curve --- -"Test binary_soft_classifition auc_roc given actual_field is always true": +"Test binary_soft_classification auc_roc given actual_field is always true": - do: catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/ ml.evaluate_data_frame: @@ -188,7 +204,7 @@ setup: } --- -"Test binary_soft_classifition auc_roc given actual_field is always false": +"Test binary_soft_classification auc_roc given actual_field is always false": - do: catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/ ml.evaluate_data_frame: @@ -207,7 +223,7 @@ setup: } --- 
-"Test binary_soft_classifition precision": +"Test binary_soft_classification precision": - do: ml.evaluate_data_frame: body: > @@ -230,7 +246,7 @@ setup: '0.5': 1.0 --- -"Test binary_soft_classifition recall": +"Test binary_soft_classification recall": - do: ml.evaluate_data_frame: body: > @@ -254,7 +270,7 @@ setup: '0.5': 0.6 --- -"Test binary_soft_classifition confusion_matrix": +"Test binary_soft_classification confusion_matrix": - do: ml.evaluate_data_frame: body: > @@ -290,7 +306,7 @@ setup: fn: 2 --- -"Test binary_soft_classifition default metrics": +"Test binary_soft_classification default metrics": - do: ml.evaluate_data_frame: body: > @@ -356,7 +372,7 @@ setup: } --- -"Test binary_soft_classification given evaluation with emtpy metrics": +"Test binary_soft_classification given evaluation with empty metrics": - do: catch: /\[binary_soft_classification\] must have one or more metrics/ ml.evaluate_data_frame: @@ -518,3 +534,52 @@ setup: } } } +--- +"Test regression given evaluation with empty metrics": + - do: + catch: /\[regression\] must have one or more metrics/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred", + "metrics": { } + } + } + } +--- +"Test regression mean_squared_error": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred", + "metrics": { "mean_squared_error": {} } + } + } + } + + - match: { regression.mean_squared_error.error: 28.67749840974834 } +--- +"Test regression with null metrics": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred" + } + } + } + + - match: { regression.mean_squared_error.error: 28.67749840974834 }