mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-25 17:38:44 +00:00
[7.x] Implement MSLE (MeanSquaredLogarithmicError) evaluation metric for regression analysis (#58684) (#58731)
This commit is contained in:
parent
b885cbff1a
commit
9ea9b7bd3b
@ -22,6 +22,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyM
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
@ -97,6 +98,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
|
||||
MeanSquaredErrorMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
|
||||
MeanSquaredLogarithmicErrorMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
||||
@ -140,6 +145,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
|
||||
MeanSquaredErrorMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
|
||||
MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
||||
|
@ -40,16 +40,13 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "mean_squared_error";
|
||||
|
||||
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER =
|
||||
new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new);
|
||||
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new);
|
||||
|
||||
public static MeanSquaredErrorMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public MeanSquaredErrorMetric() {
|
||||
|
||||
}
|
||||
public MeanSquaredErrorMetric() {}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
|
@ -0,0 +1,142 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
||||
/**
|
||||
* Calculates the mean squared logarithmic error between two known numerical fields.
|
||||
*
|
||||
* equation: msle = 1/n * Σ(log(y + offset) - log(y´ + offset))^2
|
||||
* where offset is used to make sure the argument to log function is always positive
|
||||
*/
|
||||
public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "mean_squared_logarithmic_error";
|
||||
|
||||
public static final ParseField OFFSET = new ParseField("offset");
|
||||
|
||||
private static final ConstructingObjectParser<MeanSquaredLogarithmicErrorMetric, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME, true, args -> new MeanSquaredLogarithmicErrorMetric((Double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(optionalConstructorArg(), OFFSET);
|
||||
}
|
||||
|
||||
public static MeanSquaredLogarithmicErrorMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final Double offset;
|
||||
|
||||
public MeanSquaredLogarithmicErrorMetric(@Nullable Double offset) {
|
||||
this.offset = offset;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (offset != null) {
|
||||
builder.field(OFFSET.getPreferredName(), offset);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
MeanSquaredLogarithmicErrorMetric that = (MeanSquaredLogarithmicErrorMetric) o;
|
||||
return Objects.equals(this.offset, that.offset);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(offset);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static final ParseField ERROR = new ParseField("error");
|
||||
private final double error;
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), ERROR);
|
||||
}
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR.getPreferredName(), error);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(that.error, this.error);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(error);
|
||||
}
|
||||
}
|
||||
}
|
@ -42,16 +42,13 @@ public class RSquaredMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "r_squared";
|
||||
|
||||
private static final ObjectParser<RSquaredMetric, Void> PARSER =
|
||||
new ObjectParser<>("r_squared", true, RSquaredMetric::new);
|
||||
private static final ObjectParser<RSquaredMetric, Void> PARSER = new ObjectParser<>(NAME, true, RSquaredMetric::new);
|
||||
|
||||
public static RSquaredMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public RSquaredMetric() {
|
||||
|
||||
}
|
||||
public RSquaredMetric() {}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
|
@ -142,6 +142,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyM
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
@ -1882,17 +1883,25 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||
new EvaluateDataFrameRequest(
|
||||
regressionIndex,
|
||||
null,
|
||||
new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
||||
new Regression(
|
||||
actualRegression,
|
||||
predictedRegression,
|
||||
new MeanSquaredErrorMetric(), new MeanSquaredLogarithmicErrorMetric(1.0), new RSquaredMetric()));
|
||||
|
||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(3));
|
||||
|
||||
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
|
||||
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
|
||||
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
|
||||
|
||||
MeanSquaredLogarithmicErrorMetric.Result msleResult =
|
||||
evaluateDataFrameResponse.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME);
|
||||
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
|
||||
assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9));
|
||||
|
||||
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
|
||||
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
|
||||
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
|
||||
|
@ -61,6 +61,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyM
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
@ -701,7 +702,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(64, namedXContents.size());
|
||||
assertEquals(66, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
@ -748,7 +749,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names,
|
||||
hasItems(
|
||||
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
|
||||
@ -762,8 +763,9 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||
assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names,
|
||||
hasItems(
|
||||
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
|
||||
@ -777,6 +779,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||
Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
|
||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
||||
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
|
||||
|
@ -161,6 +161,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
@ -3570,7 +3571,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
"predicted_value", // <3>
|
||||
// Evaluation metrics // <4>
|
||||
new MeanSquaredErrorMetric(), // <5>
|
||||
new RSquaredMetric()); // <6>
|
||||
new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
|
||||
new RSquaredMetric()); // <7>
|
||||
// end::evaluate-data-frame-evaluation-regression
|
||||
|
||||
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
|
||||
@ -3580,11 +3582,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
|
||||
double meanSquaredError = meanSquaredErrorResult.getError(); // <2>
|
||||
|
||||
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3>
|
||||
double rSquared = rSquaredResult.getValue(); // <4>
|
||||
MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult =
|
||||
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
|
||||
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4>
|
||||
|
||||
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <5>
|
||||
double rSquared = rSquaredResult.getValue(); // <6>
|
||||
// end::evaluate-data-frame-results-regression
|
||||
|
||||
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
|
||||
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
|
||||
assertThat(rSquared, closeTo(0.941, 1e-3));
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class MeanSquaredLogarithmicErrorMetricResultTests extends AbstractXContentTestCase<MeanSquaredLogarithmicErrorMetric.Result> {
|
||||
|
||||
public static MeanSquaredLogarithmicErrorMetric.Result randomResult() {
|
||||
return new MeanSquaredLogarithmicErrorMetric.Result(randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredLogarithmicErrorMetric.Result createTestInstance() {
|
||||
return randomResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredLogarithmicErrorMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return MeanSquaredLogarithmicErrorMetric.Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class MeanSquaredLogarithmicErrorMetricTests extends AbstractXContentTestCase<MeanSquaredLogarithmicErrorMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredLogarithmicErrorMetric createTestInstance() {
|
||||
return new MeanSquaredLogarithmicErrorMetric(randomBoolean() ? randomDouble() : null);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredLogarithmicErrorMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return MeanSquaredLogarithmicErrorMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
@ -41,6 +41,9 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new MeanSquaredErrorMetric());
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new RSquaredMetric());
|
||||
}
|
||||
|
@ -68,7 +68,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-regression]
|
||||
<3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) value for the example.
|
||||
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
|
||||
<5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error]
|
||||
<6> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
|
||||
<6> Mean squared logarithmic error
|
||||
<7> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
|
||||
|
||||
include::../execution.asciidoc[]
|
||||
|
||||
@ -123,5 +124,7 @@ include-tagged::{doc-tests-file}[{api}-results-regression]
|
||||
|
||||
<1> Fetching mean squared error metric by name
|
||||
<2> Fetching the actual mean squared error value
|
||||
<3> Fetching R squared metric by name
|
||||
<4> Fetching the actual R squared value
|
||||
<3> Fetching mean squared logarithmic error metric by name
|
||||
<4> Fetching the actual mean squared logarithmic error value
|
||||
<5> Fetching R squared metric by name
|
||||
<6> Fetching the actual R squared value
|
||||
|
@ -129,6 +129,10 @@ which outputs a prediction of values.
|
||||
(Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value.
|
||||
For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article].
|
||||
|
||||
`mean_squared_logarithmic_error`:::
|
||||
(Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
|
||||
(`ground truth`) value.
|
||||
|
||||
`r_squared`:::
|
||||
(Optional, object) Proportion of the variance in the dependent variable that is predictable from the independent variables.
|
||||
For more information, read https://en.wikipedia.org/wiki/Coefficient_of_determination[this wiki article].
|
||||
|
@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accur
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||
@ -95,6 +96,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)),
|
||||
MeanSquaredError::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
|
||||
MeanSquaredLogarithmicError::fromXContent),
|
||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
|
||||
RSquared::fromXContent)
|
||||
@ -144,6 +148,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
||||
MeanSquaredError::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
|
||||
MeanSquaredLogarithmicError::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||
registeredMetricName(Regression.NAME, RSquared.NAME),
|
||||
RSquared::new),
|
||||
@ -175,6 +182,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
||||
MeanSquaredError.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
|
||||
MeanSquaredLogarithmicError.Result::new),
|
||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
registeredMetricName(Regression.NAME, RSquared.NAME),
|
||||
RSquared.Result::new)
|
||||
|
@ -42,7 +42,9 @@ public class MeanSquaredError implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("mean_squared_error");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
|
||||
private static final String PAINLESS_TEMPLATE =
|
||||
"def diff = doc[''{0}''].value - doc[''{1}''].value;" +
|
||||
"return diff * diff;";
|
||||
private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
|
||||
|
||||
private static String buildScript(Object...args) {
|
||||
@ -143,6 +145,10 @@ public class MeanSquaredError implements EvaluationMetric {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(error);
|
||||
|
@ -0,0 +1,195 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
||||
|
||||
/**
|
||||
* Calculates the mean squared logarithmic error between two known numerical fields.
|
||||
*
|
||||
* equation: msle = 1/n * Σ(log(y + offset) - log(y´ + offset))^2
|
||||
* where offset is used to make sure the argument to log function is always positive
|
||||
*/
|
||||
public class MeanSquaredLogarithmicError implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("mean_squared_logarithmic_error");
|
||||
|
||||
public static final ParseField OFFSET = new ParseField("offset");
|
||||
private static final double DEFAULT_OFFSET = 1.0;
|
||||
|
||||
private static final String PAINLESS_TEMPLATE =
|
||||
"def offset = {2};" +
|
||||
"def diff = Math.log(doc[''{0}''].value + offset) - Math.log(doc[''{1}''].value + offset);" +
|
||||
"return diff * diff;";
|
||||
private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
|
||||
|
||||
private static String buildScript(Object...args) {
|
||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<MeanSquaredLogarithmicError, Void> PARSER =
|
||||
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MeanSquaredLogarithmicError((Double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(optionalConstructorArg(), OFFSET);
|
||||
}
|
||||
|
||||
public static MeanSquaredLogarithmicError fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private final double offset;
|
||||
private EvaluationMetricResult result;
|
||||
|
||||
public MeanSquaredLogarithmicError(StreamInput in) throws IOException {
|
||||
this.offset = in.readDouble();
|
||||
}
|
||||
|
||||
public MeanSquaredLogarithmicError(@Nullable Double offset) {
|
||||
this.offset = offset != null ? offset : DEFAULT_OFFSET;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
|
||||
String actualField,
|
||||
String predictedField) {
|
||||
if (result != null) {
|
||||
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
return Tuple.tuple(
|
||||
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, offset)))),
|
||||
Collections.emptyList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(Aggregations aggs) {
|
||||
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
|
||||
result = value == null ? new Result(0.0) : new Result(value.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<EvaluationMetricResult> getResult() {
|
||||
return Optional.ofNullable(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Regression.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(offset);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(OFFSET.getPreferredName(), offset);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
MeanSquaredLogarithmicError that = (MeanSquaredLogarithmicError) o;
|
||||
return this.offset == that.offset;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Double.hashCode(offset);
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final String ERROR = "error";
|
||||
private final double error;
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.error = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return registeredMetricName(Regression.NAME, NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(error);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR, error);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result other = (Result)o;
|
||||
return error == other.error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(error);
|
||||
}
|
||||
}
|
||||
}
|
@ -47,7 +47,9 @@ public class RSquared implements EvaluationMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("r_squared");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
|
||||
private static final String PAINLESS_TEMPLATE =
|
||||
"def diff = doc[''{0}''].value - doc[''{1}''].value;" +
|
||||
"return diff * diff;";
|
||||
private static final String SS_RES = "residual_sum_of_squares";
|
||||
|
||||
private static String buildScript(Object... args) {
|
||||
@ -158,6 +160,10 @@ public class RSquared implements EvaluationMetric {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(value);
|
||||
|
@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Multi
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
|
||||
import java.util.Arrays;
|
||||
@ -38,6 +39,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
|
||||
RecallResultTests.createRandom(),
|
||||
MulticlassConfusionMatrixResultTests.createRandom(),
|
||||
new MeanSquaredError.Result(randomDouble()),
|
||||
new MeanSquaredLogarithmicError.Result(randomDouble()),
|
||||
new RSquared.Result(randomDouble()));
|
||||
return new Response(evaluationName, randomSubsetOf(metrics));
|
||||
}
|
||||
|
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCase<MeanSquaredLogarithmicError> {
|
||||
|
||||
@Override
|
||||
protected MeanSquaredLogarithmicError doParseInstance(XContentParser parser) throws IOException {
|
||||
return MeanSquaredLogarithmicError.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected MeanSquaredLogarithmicError createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<MeanSquaredLogarithmicError> instanceReader() {
|
||||
return MeanSquaredLogarithmicError::new;
|
||||
}
|
||||
|
||||
public static MeanSquaredLogarithmicError createRandom() {
|
||||
return new MeanSquaredLogarithmicError(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
mockSingleValue("regression_mean_squared_logarithmic_error", 0.8123),
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError((Double) null);
|
||||
msle.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = msle.getResult().get();
|
||||
String expected = "{\"error\":0.8123}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingAggs() {
|
||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError((Double) null);
|
||||
msle.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = msle.getResult().get();
|
||||
assertThat(result, equalTo(new MeanSquaredLogarithmicError.Result(0.0)));
|
||||
}
|
||||
}
|
@ -0,0 +1,143 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
|
||||
public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
private static final String HOUSES_DATA_INDEX = "test-evaluate-houses-index";
|
||||
|
||||
private static final String PRICE_FIELD = "price";
|
||||
private static final String PRICE_PREDICTION_FIELD = "price_prediction";
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
createHousesIndex(HOUSES_DATA_INDEX);
|
||||
indexHousesData(HOUSES_DATA_INDEX);
|
||||
}
|
||||
|
||||
@After
|
||||
public void cleanup() {
|
||||
cleanUp();
|
||||
}
|
||||
|
||||
public void testEvaluate_DefaultMetrics() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, null));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
|
||||
contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testEvaluate_AllMetrics() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
HOUSES_DATA_INDEX,
|
||||
new Regression(
|
||||
PRICE_FIELD,
|
||||
PRICE_PREDICTION_FIELD,
|
||||
Arrays.asList(new MeanSquaredError(), new MeanSquaredLogarithmicError((Double) null), new RSquared())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
|
||||
assertThat(
|
||||
evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
|
||||
contains(
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredLogarithmicError.NAME.getPreferredName(),
|
||||
RSquared.NAME.getPreferredName()));
|
||||
}
|
||||
|
||||
public void testEvaluate_MeanSquaredError() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
HOUSES_DATA_INDEX,
|
||||
new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new MeanSquaredError())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
MeanSquaredError.Result mseResult = (MeanSquaredError.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME.getPreferredName()));
|
||||
assertThat(mseResult.getError(), equalTo(1000000.0));
|
||||
}
|
||||
|
||||
public void testEvaluate_MeanSquaredLogarithmicError() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
HOUSES_DATA_INDEX,
|
||||
new Regression(
|
||||
PRICE_FIELD,
|
||||
PRICE_PREDICTION_FIELD,
|
||||
Collections.singletonList(new MeanSquaredLogarithmicError((Double) null))));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName()));
|
||||
assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1001), 2), 10E-6));
|
||||
}
|
||||
|
||||
public void testEvaluate_RSquared() {
|
||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||
evaluateDataFrame(
|
||||
HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new RSquared())));
|
||||
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||
|
||||
RSquared.Result rSquaredResult = (RSquared.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(rSquaredResult.getMetricName(), equalTo(RSquared.NAME.getPreferredName()));
|
||||
assertThat(rSquaredResult.getValue(), equalTo(0.0));
|
||||
}
|
||||
|
||||
private static void createHousesIndex(String indexName) {
|
||||
client().admin().indices().prepareCreate(indexName)
|
||||
.addMapping("_doc",
|
||||
PRICE_FIELD, "type=double",
|
||||
PRICE_PREDICTION_FIELD, "type=double")
|
||||
.get();
|
||||
}
|
||||
|
||||
private static void indexHousesData(String indexName) {
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
for (int i = 0; i < 100; i++) {
|
||||
bulkRequestBuilder.add(
|
||||
new IndexRequest(indexName)
|
||||
.source(
|
||||
PRICE_FIELD, 1000,
|
||||
PRICE_PREDICTION_FIELD, 0));
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
}
|
||||
}
|
@ -847,6 +847,26 @@ setup:
|
||||
}
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
- is_false: regression.mean_squared_logarithmic_error.value
|
||||
- is_false: regression.r_squared.value
|
||||
---
"Test regression mean_squared_logarithmic_error":
  - do:
      ml.evaluate_data_frame:
        body:  >
          {
            "index": "utopia",
            "evaluation": {
              "regression": {
                "actual_field": "regression_field_act",
                "predicted_field": "regression_field_pred",
                "metrics": { "mean_squared_logarithmic_error": { "offset": 6.0 } }
              }
            }
          }

  - match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 }
  # Assert that the metrics which were not requested are absent as a whole.
  # Checking "mean_squared_error.value" would always pass because that metric reports "error", not "value".
  - is_false: regression.mean_squared_error
  - is_false: regression.r_squared
|
||||
---
|
||||
"Test regression r_squared":
|
||||
@ -865,6 +885,7 @@ setup:
|
||||
}
|
||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||
- is_false: regression.mean_squared_error
|
||||
- is_false: regression.mean_squared_logarithmic_error.value
|
||||
---
|
||||
"Test regression with null metrics":
|
||||
- do:
|
||||
@ -882,6 +903,7 @@ setup:
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||
- is_false: regression.mean_squared_logarithmic_error.value
|
||||
---
|
||||
"Test regression given missing actual_field":
|
||||
- do:
|
||||
|
Loading…
x
Reference in New Issue
Block a user