Rename regression evaluation metrics to make the names consistent with loss functions (#58887) (#58927)

This commit is contained in:
Przemysław Witek 2020-07-02 17:35:55 +02:00 committed by GitHub
parent 6aa669c8bb
commit 751e84e4c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 115 additions and 108 deletions

View File

@ -23,7 +23,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -105,8 +105,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
MeanSquaredLogarithmicErrorMetric::fromXContent), MeanSquaredLogarithmicErrorMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)),
PseudoHuberMetric::fromXContent), HuberMetric::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.class, EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
MeanSquaredLogarithmicErrorMetric.Result::fromXContent), MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)),
PseudoHuberMetric.Result::fromXContent), HuberMetric.Result::fromXContent),
new NamedXContentRegistry.Entry( new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class, EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)), new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),

View File

@ -18,6 +18,7 @@
*/ */
package org.elasticsearch.client.ml.dataframe.evaluation.regression; package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
@ -34,30 +35,30 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
/** /**
* Calculates the pseudo Huber loss function. * Calculates the pseudo Huber loss function.
* *
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1) * equation: huber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
* where: a = y - y´ * where: a = y - y´
* δ - parameter that controls the steepness * δ - parameter that controls the steepness
*/ */
public class PseudoHuberMetric implements EvaluationMetric { public class HuberMetric implements EvaluationMetric {
public static final String NAME = "pseudo_huber"; public static final String NAME = LossFunction.HUBER.toString();
public static final ParseField DELTA = new ParseField("delta"); public static final ParseField DELTA = new ParseField("delta");
private static final ConstructingObjectParser<PseudoHuberMetric, Void> PARSER = private static final ConstructingObjectParser<HuberMetric, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new PseudoHuberMetric((Double) args[0])); new ConstructingObjectParser<>(NAME, true, args -> new HuberMetric((Double) args[0]));
static { static {
PARSER.declareDouble(optionalConstructorArg(), DELTA); PARSER.declareDouble(optionalConstructorArg(), DELTA);
} }
public static PseudoHuberMetric fromXContent(XContentParser parser) { public static HuberMetric fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private final Double delta; private final Double delta;
public PseudoHuberMetric(@Nullable Double delta) { public HuberMetric(@Nullable Double delta) {
this.delta = delta; this.delta = delta;
} }
@ -80,7 +81,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
PseudoHuberMetric that = (PseudoHuberMetric) o; HuberMetric that = (HuberMetric) o;
return Objects.equals(this.delta, that.delta); return Objects.equals(this.delta, that.delta);
} }

View File

@ -18,6 +18,7 @@
*/ */
package org.elasticsearch.client.ml.dataframe.evaluation.regression; package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -37,7 +38,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
*/ */
public class MeanSquaredErrorMetric implements EvaluationMetric { public class MeanSquaredErrorMetric implements EvaluationMetric {
public static final String NAME = "mean_squared_error"; public static final String NAME = LossFunction.MSE.toString();
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new); private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new);

View File

@ -18,6 +18,7 @@
*/ */
package org.elasticsearch.client.ml.dataframe.evaluation.regression; package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
@ -39,7 +40,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
*/ */
public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric { public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
public static final String NAME = "mean_squared_logarithmic_error"; public static final String NAME = LossFunction.MSLE.toString();
public static final ParseField OFFSET = new ParseField("offset"); public static final ParseField OFFSET = new ParseField("offset");

View File

@ -143,7 +143,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -1889,7 +1889,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
predictedRegression, predictedRegression,
new MeanSquaredErrorMetric(), new MeanSquaredErrorMetric(),
new MeanSquaredLogarithmicErrorMetric(1.0), new MeanSquaredLogarithmicErrorMetric(1.0),
new PseudoHuberMetric(1.0), new HuberMetric(1.0),
new RSquaredMetric())); new RSquaredMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse = EvaluateDataFrameResponse evaluateDataFrameResponse =
@ -1906,9 +1906,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME)); assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9)); assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9));
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME); HuberMetric.Result huberResult = evaluateDataFrameResponse.getMetricByName(HuberMetric.NAME);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME)); assertThat(huberResult.getMetricName(), equalTo(HuberMetric.NAME));
assertThat(pseudoHuberResult.getValue(), closeTo(0.029669771640929276, 1e-9)); assertThat(huberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME); RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME)); assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));

View File

@ -62,7 +62,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -765,7 +765,7 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME), registeredMetricName(Regression.NAME, HuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names, assertThat(names,
@ -782,7 +782,7 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME), registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME), registeredMetricName(Regression.NAME, HuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME))); registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));

View File

@ -162,7 +162,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@ -3573,7 +3573,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// Evaluation metrics // <4> // Evaluation metrics // <4>
new MeanSquaredErrorMetric(), // <5> new MeanSquaredErrorMetric(), // <5>
new MeanSquaredLogarithmicErrorMetric(1.0), // <6> new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
new PseudoHuberMetric(1.0), // <7> new HuberMetric(1.0), // <7>
new RSquaredMetric()); // <8> new RSquaredMetric()); // <8>
// end::evaluate-data-frame-evaluation-regression // end::evaluate-data-frame-evaluation-regression
@ -3588,8 +3588,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3> response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4> double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4>
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5> HuberMetric.Result huberResult = response.getMetricByName(HuberMetric.NAME); // <5>
double pseudoHuber = pseudoHuberResult.getValue(); // <6> double huber = huberResult.getValue(); // <6>
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7> RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7>
double rSquared = rSquaredResult.getValue(); // <8> double rSquared = rSquaredResult.getValue(); // <8>
@ -3597,7 +3597,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
assertThat(meanSquaredError, closeTo(0.021, 1e-3)); assertThat(meanSquaredError, closeTo(0.021, 1e-3));
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3)); assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
assertThat(pseudoHuber, closeTo(0.01, 1e-3)); assertThat(huber, closeTo(0.01, 1e-3));
assertThat(rSquared, closeTo(0.941, 1e-3)); assertThat(rSquared, closeTo(0.941, 1e-3));
} }
} }

View File

@ -25,20 +25,20 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException; import java.io.IOException;
public class PseudoHuberMetricResultTests extends AbstractXContentTestCase<PseudoHuberMetric.Result> { public class HuberMetricResultTests extends AbstractXContentTestCase<HuberMetric.Result> {
public static PseudoHuberMetric.Result randomResult() { public static HuberMetric.Result randomResult() {
return new PseudoHuberMetric.Result(randomDouble()); return new HuberMetric.Result(randomDouble());
} }
@Override @Override
protected PseudoHuberMetric.Result createTestInstance() { protected HuberMetric.Result createTestInstance() {
return randomResult(); return randomResult();
} }
@Override @Override
protected PseudoHuberMetric.Result doParseInstance(XContentParser parser) throws IOException { protected HuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
return PseudoHuberMetric.Result.fromXContent(parser); return HuberMetric.Result.fromXContent(parser);
} }
@Override @Override

View File

@ -25,7 +25,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException; import java.io.IOException;
public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuberMetric> { public class HuberMetricTests extends AbstractXContentTestCase<HuberMetric> {
@Override @Override
protected NamedXContentRegistry xContentRegistry() { protected NamedXContentRegistry xContentRegistry() {
@ -33,13 +33,13 @@ public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuber
} }
@Override @Override
protected PseudoHuberMetric createTestInstance() { protected HuberMetric createTestInstance() {
return new PseudoHuberMetric(randomBoolean() ? randomDouble() : null); return new HuberMetric(randomBoolean() ? randomDouble() : null);
} }
@Override @Override
protected PseudoHuberMetric doParseInstance(XContentParser parser) throws IOException { protected HuberMetric doParseInstance(XContentParser parser) throws IOException {
return PseudoHuberMetric.fromXContent(parser); return HuberMetric.fromXContent(parser);
} }
@Override @Override

View File

@ -45,7 +45,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance()); metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
} }
if (randomBoolean()) { if (randomBoolean()) {
metrics.add(new PseudoHuberMetricTests().createTestInstance()); metrics.add(new HuberMetricTests().createTestInstance());
} }
if (randomBoolean()) { if (randomBoolean()) {
metrics.add(new RSquaredMetric()); metrics.add(new RSquaredMetric());

View File

@ -125,15 +125,15 @@ which outputs a prediction of values.
(Optional, object) Specifies the metrics that are used for the evaluation. (Optional, object) Specifies the metrics that are used for the evaluation.
Available metrics: Available metrics:
`mean_squared_error`::: `mse`:::
(Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value. (Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value.
For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article]. For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article].
`mean_squared_logarithmic_error`::: `msle`:::
(Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual (Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
(`ground truth`) value. (`ground truth`) value.
`pseudo_huber`::: `huber`:::
(Optional, object) Pseudo Huber loss function. (Optional, object) Pseudo Huber loss function.
For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article]. For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article].
@ -280,7 +280,7 @@ POST _ml/data_frame/_evaluate
"predicted_field": "ml.price_prediction", <4> "predicted_field": "ml.price_prediction", <4>
"metrics": { "metrics": {
"r_squared": {}, "r_squared": {},
"mean_squared_error": {} "mse": {}
} }
} }
} }
@ -317,7 +317,7 @@ POST _ml/data_frame/_evaluate
"predicted_field": "ml.G3_prediction", <3> "predicted_field": "ml.G3_prediction", <3>
"metrics": { "metrics": {
"r_squared": {}, "r_squared": {},
"mean_squared_error": {} "mse": {}
} }
} }
} }
@ -356,7 +356,7 @@ POST _ml/data_frame/_evaluate
"predicted_field": "ml.G3_prediction", <3> "predicted_field": "ml.G3_prediction", <3>
"metrics": { "metrics": {
"r_squared": {}, "r_squared": {},
"mean_squared_error": {} "mse": {}
} }
} }
} }

View File

@ -12,9 +12,9 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@ -101,8 +101,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)), new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
MeanSquaredLogarithmicError::fromXContent), MeanSquaredLogarithmicError::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class, new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuber.NAME)), new ParseField(registeredMetricName(Regression.NAME, Huber.NAME)),
PseudoHuber::fromXContent), Huber::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class, new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)), new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
RSquared::fromXContent) RSquared::fromXContent)
@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError::new), MeanSquaredLogarithmicError::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class, new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME), registeredMetricName(Regression.NAME, Huber.NAME),
PseudoHuber::new), Huber::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class, new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, RSquared.NAME), registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared::new), RSquared::new),
@ -193,8 +193,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME), registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError.Result::new), MeanSquaredLogarithmicError.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME), registeredMetricName(Regression.NAME, Huber.NAME),
PseudoHuber.Result::new), Huber.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, RSquared.NAME), registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared.Result::new) RSquared.Result::new)

View File

@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -37,13 +38,13 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
/** /**
* Calculates the pseudo Huber loss function. * Calculates the pseudo Huber loss function.
* *
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1) * equation: huber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
* where: a = y - y´ * where: a = y - y´
* δ - parameter that controls the steepness * δ - parameter that controls the steepness
*/ */
public class PseudoHuber implements EvaluationMetric { public class Huber implements EvaluationMetric {
public static final ParseField NAME = new ParseField("pseudo_huber"); public static final ParseField NAME = new ParseField(LossFunction.HUBER.toString());
public static final ParseField DELTA = new ParseField("delta"); public static final ParseField DELTA = new ParseField("delta");
private static final double DEFAULT_DELTA = 1.0; private static final double DEFAULT_DELTA = 1.0;
@ -58,25 +59,25 @@ public class PseudoHuber implements EvaluationMetric {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args); return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
} }
private static final ConstructingObjectParser<PseudoHuber, Void> PARSER = private static final ConstructingObjectParser<Huber, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new PseudoHuber((Double) args[0])); new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new Huber((Double) args[0]));
static { static {
PARSER.declareDouble(optionalConstructorArg(), DELTA); PARSER.declareDouble(optionalConstructorArg(), DELTA);
} }
public static PseudoHuber fromXContent(XContentParser parser) { public static Huber fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private final double delta; private final double delta;
private EvaluationMetricResult result; private EvaluationMetricResult result;
public PseudoHuber(StreamInput in) throws IOException { public Huber(StreamInput in) throws IOException {
this.delta = in.readDouble(); this.delta = in.readDouble();
} }
public PseudoHuber(@Nullable Double delta) { public Huber(@Nullable Double delta) {
this.delta = delta != null ? delta : DEFAULT_DELTA; this.delta = delta != null ? delta : DEFAULT_DELTA;
} }
@ -130,7 +131,7 @@ public class PseudoHuber implements EvaluationMetric {
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
PseudoHuber that = (PseudoHuber) o; Huber that = (Huber) o;
return this.delta == that.delta; return this.delta == that.delta;
} }

View File

@ -18,6 +18,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -40,7 +41,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
*/ */
public class MeanSquaredError implements EvaluationMetric { public class MeanSquaredError implements EvaluationMetric {
public static final ParseField NAME = new ParseField("mean_squared_error"); public static final ParseField NAME = new ParseField(LossFunction.MSE.toString());
private static final String PAINLESS_TEMPLATE = private static final String PAINLESS_TEMPLATE =
"def diff = doc[''{0}''].value - doc[''{1}''].value;" + "def diff = doc[''{0}''].value - doc[''{1}''].value;" +

View File

@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -42,7 +43,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
*/ */
public class MeanSquaredLogarithmicError implements EvaluationMetric { public class MeanSquaredLogarithmicError implements EvaluationMetric {
public static final ParseField NAME = new ParseField("mean_squared_logarithmic_error"); public static final ParseField NAME = new ParseField(LossFunction.MSLE.toString());
public static final ParseField OFFSET = new ParseField("offset"); public static final ParseField OFFSET = new ParseField("offset");
private static final double DEFAULT_OFFSET = 1.0; private static final double DEFAULT_OFFSET = 1.0;

View File

@ -17,7 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Preci
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import java.util.Arrays; import java.util.Arrays;
@ -41,7 +41,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
MulticlassConfusionMatrixResultTests.createRandom(), MulticlassConfusionMatrixResultTests.createRandom(),
new MeanSquaredError.Result(randomDouble()), new MeanSquaredError.Result(randomDouble()),
new MeanSquaredLogarithmicError.Result(randomDouble()), new MeanSquaredLogarithmicError.Result(randomDouble()),
new PseudoHuber.Result(randomDouble()), new Huber.Result(randomDouble()),
new RSquared.Result(randomDouble())); new RSquared.Result(randomDouble()));
return new Response(evaluationName, randomSubsetOf(metrics)); return new Response(evaluationName, randomSubsetOf(metrics));
} }

View File

@ -19,37 +19,37 @@ import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue; import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> { public class HuberTests extends AbstractSerializingTestCase<Huber> {
@Override @Override
protected PseudoHuber doParseInstance(XContentParser parser) throws IOException { protected Huber doParseInstance(XContentParser parser) throws IOException {
return PseudoHuber.fromXContent(parser); return Huber.fromXContent(parser);
} }
@Override @Override
protected PseudoHuber createTestInstance() { protected Huber createTestInstance() {
return createRandom(); return createRandom();
} }
@Override @Override
protected Writeable.Reader<PseudoHuber> instanceReader() { protected Writeable.Reader<Huber> instanceReader() {
return PseudoHuber::new; return Huber::new;
} }
public static PseudoHuber createRandom() { public static Huber createRandom() {
return new PseudoHuber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null); return new Huber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
} }
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("regression_pseudo_huber", 0.8123), mockSingleValue("regression_huber", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
PseudoHuber pseudoHuber = new PseudoHuber((Double) null); Huber huber = new Huber((Double) null);
pseudoHuber.process(aggs); huber.process(aggs);
EvaluationMetricResult result = pseudoHuber.getResult().get(); EvaluationMetricResult result = huber.getResult().get();
String expected = "{\"value\":0.8123}"; String expected = "{\"value\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected)); assertThat(Strings.toString(result), equalTo(expected));
} }
@ -59,10 +59,10 @@ public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> {
mockSingleValue("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));
PseudoHuber pseudoHuber = new PseudoHuber((Double) null); Huber huber = new Huber((Double) null);
pseudoHuber.process(aggs); huber.process(aggs);
EvaluationMetricResult result = pseudoHuber.getResult().get(); EvaluationMetricResult result = huber.getResult().get();
assertThat(result, equalTo(new PseudoHuber.Result(0.0))); assertThat(result, equalTo(new Huber.Result(0.0)));
} }
} }

View File

@ -42,7 +42,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("regression_mean_squared_error", 0.8123), mockSingleValue("regression_mse", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));

View File

@ -42,7 +42,7 @@ public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCas
public void testEvaluate() { public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList( Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("regression_mean_squared_logarithmic_error", 0.8123), mockSingleValue("regression_msle", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377) mockSingleValue("some_other_single_metric_agg", 0.2377)
)); ));

View File

@ -11,9 +11,9 @@ import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.junit.After; import org.junit.After;
@ -105,18 +105,18 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6)); assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
} }
public void testEvaluate_PseudoHuber() { public void testEvaluate_Huber() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse = EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame( evaluateDataFrame(
HOUSES_DATA_INDEX, HOUSES_DATA_INDEX,
new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new PseudoHuber((Double) null)))); new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new Huber((Double) null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName())); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1)); assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
PseudoHuber.Result pseudoHuberResult = (PseudoHuber.Result) evaluateDataFrameResponse.getMetrics().get(0); Huber.Result huberResult = (Huber.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuber.NAME.getPreferredName())); assertThat(huberResult.getMetricName(), equalTo(Huber.NAME.getPreferredName()));
assertThat(pseudoHuberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6)); assertThat(huberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6));
} }
public void testEvaluate_RSquared() { public void testEvaluate_RSquared() {

View File

@ -841,15 +841,15 @@ setup:
"regression": { "regression": {
"actual_field": "regression_field_act", "actual_field": "regression_field_act",
"predicted_field": "regression_field_pred", "predicted_field": "regression_field_pred",
"metrics": { "mean_squared_error": {} } "metrics": { "mse": {} }
} }
} }
} }
- match: { regression.mean_squared_error.value: 28.67749840974834 } - match: { regression.mse.value: 28.67749840974834 }
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.msle.value
- is_false: regression.r_squared.value - is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value - is_false: regression.huber.value
--- ---
"Test regression mean_squared_logarithmic_error": "Test regression mean_squared_logarithmic_error":
- do: - do:
@ -861,17 +861,17 @@ setup:
"regression": { "regression": {
"actual_field": "regression_field_act", "actual_field": "regression_field_act",
"predicted_field": "regression_field_pred", "predicted_field": "regression_field_pred",
"metrics": { "mean_squared_logarithmic_error": { "offset": 6.0 } } "metrics": { "msle": { "offset": 6.0 } }
} }
} }
} }
- match: { regression.mean_squared_logarithmic_error.value: 0.08680568028334916 } - match: { regression.msle.value: 0.08680568028334916 }
- is_false: regression.mean_squared_error.value - is_false: regression.mse.value
- is_false: regression.r_squared.value - is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value - is_false: regression.huber.value
--- ---
"Test regression pseudo_huber": "Test regression huber":
- do: - do:
ml.evaluate_data_frame: ml.evaluate_data_frame:
body: > body: >
@ -881,14 +881,14 @@ setup:
"regression": { "regression": {
"actual_field": "regression_field_act", "actual_field": "regression_field_act",
"predicted_field": "regression_field_pred", "predicted_field": "regression_field_pred",
"metrics": { "pseudo_huber": { "delta": 2.0 } } "metrics": { "huber": { "delta": 2.0 } }
} }
} }
} }
- match: { regression.pseudo_huber.value: 3.5088110471730145 } - match: { regression.huber.value: 3.5088110471730145 }
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.msle.value
- is_false: regression.mean_squared_error.value - is_false: regression.mse.value
- is_false: regression.r_squared.value - is_false: regression.r_squared.value
--- ---
"Test regression r_squared": "Test regression r_squared":
@ -906,9 +906,9 @@ setup:
} }
} }
- match: { regression.r_squared.value: 0.8551031778603486 } - match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_error - is_false: regression.mse
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.msle.value
- is_false: regression.pseudo_huber.value - is_false: regression.huber.value
--- ---
"Test regression with null metrics": "Test regression with null metrics":
@ -925,9 +925,10 @@ setup:
} }
} }
- match: { regression.mean_squared_error.value: 28.67749840974834 } - match: { regression.mse.value: 28.67749840974834 }
- match: { regression.r_squared.value: 0.8551031778603486 } - match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.msle.value
- is_false: regression.huber.value
--- ---
"Test regression given missing actual_field": "Test regression given missing actual_field":
- do: - do: