diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 466d3e1b1a2..0c8094a208f 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -23,7 +23,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -105,8 +105,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider MeanSquaredLogarithmicErrorMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, - new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)), - PseudoHuberMetric::fromXContent), + new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)), + HuberMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), @@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider MeanSquaredLogarithmicErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, - new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)), - PseudoHuberMetric.Result::fromXContent), + new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)), + HuberMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetric.java similarity index 87% rename from client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetric.java rename to client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetric.java index 0db2cd44099..90c2e7ce2b6 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetric.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation.regression; +import org.elasticsearch.client.ml.dataframe.Regression.LossFunction; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -34,30 +35,30 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona /** * Calculates the pseudo Huber loss function. * - * equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1) + * equation: huber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1) * where: a = y - y´ * δ - parameter that controls the steepness */ -public class PseudoHuberMetric implements EvaluationMetric { +public class HuberMetric implements EvaluationMetric { - public static final String NAME = "pseudo_huber"; + public static final String NAME = LossFunction.HUBER.toString(); public static final ParseField DELTA = new ParseField("delta"); - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME, true, args -> new PseudoHuberMetric((Double) args[0])); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, args -> new HuberMetric((Double) args[0])); static { PARSER.declareDouble(optionalConstructorArg(), DELTA); } - public static PseudoHuberMetric fromXContent(XContentParser parser) { + public static HuberMetric fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } private final Double delta; - public PseudoHuberMetric(@Nullable Double delta) { + public HuberMetric(@Nullable Double delta) { this.delta = delta; } @@ -80,7 +81,7 @@ public class PseudoHuberMetric implements EvaluationMetric { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - PseudoHuberMetric that = (PseudoHuberMetric) o; + HuberMetric that = (HuberMetric) o; return Objects.equals(this.delta, that.delta); } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java index 152e117a5b4..e505693f465 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation.regression; +import org.elasticsearch.client.ml.dataframe.Regression.LossFunction; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -37,7 +38,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru */ public class MeanSquaredErrorMetric implements EvaluationMetric { - public static final String NAME = "mean_squared_error"; + public static final String NAME = LossFunction.MSE.toString(); private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java index 4593fe08799..1c9eccedcf9 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.client.ml.dataframe.evaluation.regression; +import org.elasticsearch.client.ml.dataframe.Regression.LossFunction; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; @@ -39,7 +40,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona */ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric { - public static final String NAME = "mean_squared_logarithmic_error"; + public static final String NAME = LossFunction.MSLE.toString(); public static final ParseField OFFSET = new ParseField("offset"); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 3d21c66e480..59e2304aa0f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -143,7 +143,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -1889,7 +1889,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { predictedRegression, new MeanSquaredErrorMetric(), new MeanSquaredLogarithmicErrorMetric(1.0), - new PseudoHuberMetric(1.0), + new HuberMetric(1.0), new RSquaredMetric())); EvaluateDataFrameResponse evaluateDataFrameResponse = @@ -1906,9 +1906,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME)); assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9)); - PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME); - assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME)); - assertThat(pseudoHuberResult.getValue(), closeTo(0.029669771640929276, 1e-9)); + HuberMetric.Result huberResult = evaluateDataFrameResponse.getMetricByName(HuberMetric.NAME); + assertThat(huberResult.getMetricName(), equalTo(HuberMetric.NAME)); + assertThat(huberResult.getValue(), closeTo(0.029669771640929276, 1e-9)); RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME); assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index d52b4ba884e..835012eabb4 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -62,7 +62,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -765,7 +765,7 @@ public class RestHighLevelClientTests extends ESTestCase { registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), - registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME), + registeredMetricName(Regression.NAME, HuberMetric.NAME), registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, @@ -782,7 +782,7 @@ public class RestHighLevelClientTests extends ESTestCase { registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), - registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME), + registeredMetricName(Regression.NAME, HuberMetric.NAME), registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 25725f9fc6d..2c0081fcf33 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -162,7 +162,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -3573,7 +3573,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { // Evaluation metrics // <4> new MeanSquaredErrorMetric(), // <5> new MeanSquaredLogarithmicErrorMetric(1.0), // <6> - new PseudoHuberMetric(1.0), // <7> + new HuberMetric(1.0), // <7> new RSquaredMetric()); // <8> // end::evaluate-data-frame-evaluation-regression @@ -3588,8 +3588,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3> double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4> - PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5> - double pseudoHuber = pseudoHuberResult.getValue(); // <6> + HuberMetric.Result huberResult = response.getMetricByName(HuberMetric.NAME); // <5> + double huber = huberResult.getValue(); // <6> RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7> double rSquared = rSquaredResult.getValue(); // <8> @@ -3597,7 +3597,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { assertThat(meanSquaredError, closeTo(0.021, 1e-3)); assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3)); - assertThat(pseudoHuber, closeTo(0.01, 1e-3)); + assertThat(huber, closeTo(0.01, 1e-3)); assertThat(rSquared, closeTo(0.941, 1e-3)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetricResultTests.java similarity index 77% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricResultTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetricResultTests.java index d2346a0b438..20e6bf1e191 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricResultTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetricResultTests.java @@ -25,20 +25,20 @@ import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; -public class PseudoHuberMetricResultTests extends AbstractXContentTestCase { +public class HuberMetricResultTests extends AbstractXContentTestCase { - public static PseudoHuberMetric.Result randomResult() { - return new PseudoHuberMetric.Result(randomDouble()); + public static HuberMetric.Result randomResult() { + return new HuberMetric.Result(randomDouble()); } @Override - protected PseudoHuberMetric.Result createTestInstance() { + protected HuberMetric.Result createTestInstance() { return randomResult(); } @Override - protected PseudoHuberMetric.Result doParseInstance(XContentParser parser) throws IOException { - return PseudoHuberMetric.Result.fromXContent(parser); + protected HuberMetric.Result doParseInstance(XContentParser parser) throws IOException { + return HuberMetric.Result.fromXContent(parser); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetricTests.java similarity index 79% rename from client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricTests.java rename to client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetricTests.java index 1293f728bfe..3a8af7b692e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/PseudoHuberMetricTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/HuberMetricTests.java @@ -25,7 +25,7 @@ import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; -public class PseudoHuberMetricTests extends AbstractXContentTestCase { +public class HuberMetricTests extends AbstractXContentTestCase { @Override protected NamedXContentRegistry xContentRegistry() { @@ -33,13 +33,13 @@ public class PseudoHuberMetricTests extends AbstractXContentTestCase { metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance()); } if (randomBoolean()) { - metrics.add(new PseudoHuberMetricTests().createTestInstance()); + metrics.add(new HuberMetricTests().createTestInstance()); } if (randomBoolean()) { metrics.add(new RSquaredMetric()); diff --git a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc index 54e9907b835..ffd5775289d 100644 --- a/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc @@ -125,15 +125,15 @@ which outputs a prediction of values. (Optional, object) Specifies the metrics that are used for the evaluation. Available metrics: - `mean_squared_error`::: + `mse`::: (Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value. For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article]. - `mean_squared_logarithmic_error`::: + `msle`::: (Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual (`ground truth`) value. - `pseudo_huber`::: + `huber`::: (Optional, object) Pseudo Huber loss function. For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article]. @@ -280,7 +280,7 @@ POST _ml/data_frame/_evaluate "predicted_field": "ml.price_prediction", <4> "metrics": { "r_squared": {}, - "mean_squared_error": {} + "mse": {} } } } @@ -317,7 +317,7 @@ POST _ml/data_frame/_evaluate "predicted_field": "ml.G3_prediction", <3> "metrics": { "r_squared": {}, - "mean_squared_error": {} + "mse": {} } } } @@ -356,7 +356,7 @@ POST _ml/data_frame/_evaluate "predicted_field": "ml.G3_prediction", <3> "metrics": { "r_squared": {}, - "mean_squared_error": {} + "mse": {} } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index d22c6a32f0f..300b0b968ee 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -12,9 +12,9 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; @@ -101,8 +101,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)), MeanSquaredLogarithmicError::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, - new ParseField(registeredMetricName(Regression.NAME, PseudoHuber.NAME)), - PseudoHuber::fromXContent), + new ParseField(registeredMetricName(Regression.NAME, Huber.NAME)), + Huber::fromXContent), new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)), RSquared::fromXContent) @@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), MeanSquaredLogarithmicError::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, - registeredMetricName(Regression.NAME, PseudoHuber.NAME), - PseudoHuber::new), + registeredMetricName(Regression.NAME, Huber.NAME), + Huber::new), new NamedWriteableRegistry.Entry(EvaluationMetric.class, registeredMetricName(Regression.NAME, RSquared.NAME), RSquared::new), @@ -193,8 +193,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), MeanSquaredLogarithmicError.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, - registeredMetricName(Regression.NAME, PseudoHuber.NAME), - PseudoHuber.Result::new), + registeredMetricName(Regression.NAME, Huber.NAME), + Huber.Result::new), new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, registeredMetricName(Regression.NAME, RSquared.NAME), RSquared.Result::new) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/PseudoHuber.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java similarity index 91% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/PseudoHuber.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java index 8c8ed3b31c6..7be8946b293 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/PseudoHuber.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Huber.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -37,13 +38,13 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN /** * Calculates the pseudo Huber loss function. * - * equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1) + * equation: huber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1) * where: a = y - y´ * δ - parameter that controls the steepness */ -public class PseudoHuber implements EvaluationMetric { +public class Huber implements EvaluationMetric { - public static final ParseField NAME = new ParseField("pseudo_huber"); + public static final ParseField NAME = new ParseField(LossFunction.HUBER.toString()); public static final ParseField DELTA = new ParseField("delta"); private static final double DEFAULT_DELTA = 1.0; @@ -58,25 +59,25 @@ public class PseudoHuber implements EvaluationMetric { return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args); } - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new PseudoHuber((Double) args[0])); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new Huber((Double) args[0])); static { PARSER.declareDouble(optionalConstructorArg(), DELTA); } - public static PseudoHuber fromXContent(XContentParser parser) { + public static Huber fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } private final double delta; private EvaluationMetricResult result; - public PseudoHuber(StreamInput in) throws IOException { + public Huber(StreamInput in) throws IOException { this.delta = in.readDouble(); } - public PseudoHuber(@Nullable Double delta) { + public Huber(@Nullable Double delta) { this.delta = delta != null ? delta : DEFAULT_DELTA; } @@ -130,7 +131,7 @@ public class PseudoHuber implements EvaluationMetric { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - PseudoHuber that = (PseudoHuber) o; + Huber that = (Huber) o; return this.delta == that.delta; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index 6f740b0e012..2637109646a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -18,6 +18,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -40,7 +41,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN */ public class MeanSquaredError implements EvaluationMetric { - public static final ParseField NAME = new ParseField("mean_squared_error"); + public static final ParseField NAME = new ParseField(LossFunction.MSE.toString()); private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;" + diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java index d6b0d0bae2f..af2af28ce04 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; @@ -42,7 +43,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN */ public class MeanSquaredLogarithmicError implements EvaluationMetric { - public static final ParseField NAME = new ParseField("mean_squared_logarithmic_error"); + public static final ParseField NAME = new ParseField(LossFunction.MSLE.toString()); public static final ParseField OFFSET = new ParseField("offset"); private static final double DEFAULT_OFFSET = 1.0; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java index a68f6006221..02767366c67 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java @@ -17,7 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Preci import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import java.util.Arrays; @@ -41,7 +41,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin MulticlassConfusionMatrixResultTests.createRandom(), new MeanSquaredError.Result(randomDouble()), new MeanSquaredLogarithmicError.Result(randomDouble()), - new PseudoHuber.Result(randomDouble()), + new Huber.Result(randomDouble()), new RSquared.Result(randomDouble())); return new Response(evaluationName, randomSubsetOf(metrics)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/PseudoHuberTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/HuberTests.java similarity index 60% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/PseudoHuberTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/HuberTests.java index dee1ea61c47..91bb6bc510e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/PseudoHuberTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/HuberTests.java @@ -19,37 +19,37 @@ import java.util.Collections; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.hamcrest.Matchers.equalTo; -public class PseudoHuberTests extends AbstractSerializingTestCase { +public class HuberTests extends AbstractSerializingTestCase { @Override - protected PseudoHuber doParseInstance(XContentParser parser) throws IOException { - return PseudoHuber.fromXContent(parser); + protected Huber doParseInstance(XContentParser parser) throws IOException { + return Huber.fromXContent(parser); } @Override - protected PseudoHuber createTestInstance() { + protected Huber createTestInstance() { return createRandom(); } @Override - protected Writeable.Reader instanceReader() { - return PseudoHuber::new; + protected Writeable.Reader instanceReader() { + return Huber::new; } - public static PseudoHuber createRandom() { - return new PseudoHuber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null); + public static Huber createRandom() { + return new Huber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null); } public void testEvaluate() { Aggregations aggs = new Aggregations(Arrays.asList( - mockSingleValue("regression_pseudo_huber", 0.8123), + mockSingleValue("regression_huber", 0.8123), mockSingleValue("some_other_single_metric_agg", 0.2377) )); - PseudoHuber pseudoHuber = new PseudoHuber((Double) null); - pseudoHuber.process(aggs); + Huber huber = new Huber((Double) null); + huber.process(aggs); - EvaluationMetricResult result = pseudoHuber.getResult().get(); + EvaluationMetricResult result = huber.getResult().get(); String expected = "{\"value\":0.8123}"; assertThat(Strings.toString(result), equalTo(expected)); } @@ -59,10 +59,10 @@ public class PseudoHuberTests extends AbstractSerializingTestCase { mockSingleValue("some_other_single_metric_agg", 0.2377) )); - PseudoHuber pseudoHuber = new PseudoHuber((Double) null); - pseudoHuber.process(aggs); + Huber huber = new Huber((Double) null); + huber.process(aggs); - EvaluationMetricResult result = pseudoHuber.getResult().get(); - assertThat(result, equalTo(new PseudoHuber.Result(0.0))); + EvaluationMetricResult result = huber.getResult().get(); + assertThat(result, equalTo(new Huber.Result(0.0))); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java index a723bc3fbce..3d8fbd2374b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -42,7 +42,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase @@ -881,14 +881,14 @@ setup: "regression": { "actual_field": "regression_field_act", "predicted_field": "regression_field_pred", - "metrics": { "pseudo_huber": { "delta": 2.0 } } + "metrics": { "huber": { "delta": 2.0 } } } } } - - match: { regression.pseudo_huber.value: 3.5088110471730145 } - - is_false: regression.mean_squared_logarithmic_error.value - - is_false: regression.mean_squared_error.value + - match: { regression.huber.value: 3.5088110471730145 } + - is_false: regression.msle.value + - is_false: regression.mse.value - is_false: regression.r_squared.value --- "Test regression r_squared": @@ -906,9 +906,9 @@ setup: } } - match: { regression.r_squared.value: 0.8551031778603486 } - - is_false: regression.mean_squared_error - - is_false: regression.mean_squared_logarithmic_error.value - - is_false: regression.pseudo_huber.value + - is_false: regression.mse + - is_false: regression.msle.value + - is_false: regression.huber.value --- "Test regression with null metrics": @@ -925,9 +925,10 @@ setup: } } - - match: { regression.mean_squared_error.value: 28.67749840974834 } + - match: { regression.mse.value: 28.67749840974834 } - match: { regression.r_squared.value: 0.8551031778603486 } - - is_false: regression.mean_squared_logarithmic_error.value + - is_false: regression.msle.value + - is_false: regression.huber.value --- "Test regression given missing actual_field": - do: