Rename regression evaluation metrics to make the names consistent with loss functions (#58887) (#58927)

This commit is contained in:
Przemysław Witek 2020-07-02 17:35:55 +02:00 committed by GitHub
parent 6aa669c8bb
commit 751e84e4c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 115 additions and 108 deletions

View File

@ -23,7 +23,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -105,8 +105,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
MeanSquaredLogarithmicErrorMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
PseudoHuberMetric::fromXContent),
new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)),
HuberMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
PseudoHuberMetric.Result::fromXContent),
new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)),
HuberMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
@ -34,30 +35,30 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
/**
* Calculates the pseudo Huber loss function.
*
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
* equation: huber = 1/n * Σ(δ^2 * (sqrt(1 + a^2 / δ^2) - 1))
* where: a = y - y´
* δ - parameter that controls the steepness
*/
public class PseudoHuberMetric implements EvaluationMetric {
public class HuberMetric implements EvaluationMetric {
public static final String NAME = "pseudo_huber";
public static final String NAME = LossFunction.HUBER.toString();
public static final ParseField DELTA = new ParseField("delta");
private static final ConstructingObjectParser<PseudoHuberMetric, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new PseudoHuberMetric((Double) args[0]));
private static final ConstructingObjectParser<HuberMetric, Void> PARSER =
new ConstructingObjectParser<>(NAME, true, args -> new HuberMetric((Double) args[0]));
static {
PARSER.declareDouble(optionalConstructorArg(), DELTA);
}
public static PseudoHuberMetric fromXContent(XContentParser parser) {
public static HuberMetric fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final Double delta;
public PseudoHuberMetric(@Nullable Double delta) {
public HuberMetric(@Nullable Double delta) {
this.delta = delta;
}
@ -80,7 +81,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PseudoHuberMetric that = (PseudoHuberMetric) o;
HuberMetric that = (HuberMetric) o;
return Objects.equals(this.delta, that.delta);
}

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -37,7 +38,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
*/
public class MeanSquaredErrorMetric implements EvaluationMetric {
public static final String NAME = "mean_squared_error";
public static final String NAME = LossFunction.MSE.toString();
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new);

View File

@ -18,6 +18,7 @@
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
@ -39,7 +40,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
*/
public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
public static final String NAME = "mean_squared_logarithmic_error";
public static final String NAME = LossFunction.MSLE.toString();
public static final ParseField OFFSET = new ParseField("offset");

View File

@ -143,7 +143,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -1889,7 +1889,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
predictedRegression,
new MeanSquaredErrorMetric(),
new MeanSquaredLogarithmicErrorMetric(1.0),
new PseudoHuberMetric(1.0),
new HuberMetric(1.0),
new RSquaredMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse =
@ -1906,9 +1906,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9));
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));
assertThat(pseudoHuberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
HuberMetric.Result huberResult = evaluateDataFrameResponse.getMetricByName(HuberMetric.NAME);
assertThat(huberResult.getMetricName(), equalTo(HuberMetric.NAME));
assertThat(huberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));

View File

@ -62,7 +62,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -765,7 +765,7 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
registeredMetricName(Regression.NAME, HuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names,
@ -782,7 +782,7 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
registeredMetricName(Regression.NAME, HuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));

View File

@ -162,7 +162,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@ -3573,7 +3573,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// Evaluation metrics // <4>
new MeanSquaredErrorMetric(), // <5>
new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
new PseudoHuberMetric(1.0), // <7>
new HuberMetric(1.0), // <7>
new RSquaredMetric()); // <8>
// end::evaluate-data-frame-evaluation-regression
@ -3588,8 +3588,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4>
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
double pseudoHuber = pseudoHuberResult.getValue(); // <6>
HuberMetric.Result huberResult = response.getMetricByName(HuberMetric.NAME); // <5>
double huber = huberResult.getValue(); // <6>
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7>
double rSquared = rSquaredResult.getValue(); // <8>
@ -3597,7 +3597,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
assertThat(pseudoHuber, closeTo(0.01, 1e-3));
assertThat(huber, closeTo(0.01, 1e-3));
assertThat(rSquared, closeTo(0.941, 1e-3));
}
}

View File

@ -25,20 +25,20 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class PseudoHuberMetricResultTests extends AbstractXContentTestCase<PseudoHuberMetric.Result> {
public class HuberMetricResultTests extends AbstractXContentTestCase<HuberMetric.Result> {
public static PseudoHuberMetric.Result randomResult() {
return new PseudoHuberMetric.Result(randomDouble());
public static HuberMetric.Result randomResult() {
return new HuberMetric.Result(randomDouble());
}
@Override
protected PseudoHuberMetric.Result createTestInstance() {
protected HuberMetric.Result createTestInstance() {
return randomResult();
}
@Override
protected PseudoHuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
return PseudoHuberMetric.Result.fromXContent(parser);
protected HuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
return HuberMetric.Result.fromXContent(parser);
}
@Override

View File

@ -25,7 +25,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuberMetric> {
public class HuberMetricTests extends AbstractXContentTestCase<HuberMetric> {
@Override
protected NamedXContentRegistry xContentRegistry() {
@ -33,13 +33,13 @@ public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuber
}
@Override
protected PseudoHuberMetric createTestInstance() {
return new PseudoHuberMetric(randomBoolean() ? randomDouble() : null);
protected HuberMetric createTestInstance() {
return new HuberMetric(randomBoolean() ? randomDouble() : null);
}
@Override
protected PseudoHuberMetric doParseInstance(XContentParser parser) throws IOException {
return PseudoHuberMetric.fromXContent(parser);
protected HuberMetric doParseInstance(XContentParser parser) throws IOException {
return HuberMetric.fromXContent(parser);
}
@Override

View File

@ -45,7 +45,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
}
if (randomBoolean()) {
metrics.add(new PseudoHuberMetricTests().createTestInstance());
metrics.add(new HuberMetricTests().createTestInstance());
}
if (randomBoolean()) {
metrics.add(new RSquaredMetric());

View File

@ -125,15 +125,15 @@ which outputs a prediction of values.
(Optional, object) Specifies the metrics that are used for the evaluation.
Available metrics:
`mean_squared_error`:::
`mse`:::
(Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) values.
For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article].
`mean_squared_logarithmic_error`:::
`msle`:::
(Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
(`ground truth`) values.
`pseudo_huber`:::
`huber`:::
(Optional, object) Pseudo Huber loss function.
For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article].
@ -280,7 +280,7 @@ POST _ml/data_frame/_evaluate
"predicted_field": "ml.price_prediction", <4>
"metrics": {
"r_squared": {},
"mean_squared_error": {}
"mse": {}
}
}
}
@ -317,7 +317,7 @@ POST _ml/data_frame/_evaluate
"predicted_field": "ml.G3_prediction", <3>
"metrics": {
"r_squared": {},
"mean_squared_error": {}
"mse": {}
}
}
}
@ -356,7 +356,7 @@ POST _ml/data_frame/_evaluate
"predicted_field": "ml.G3_prediction", <3>
"metrics": {
"r_squared": {},
"mean_squared_error": {}
"mse": {}
}
}
}

View File

@ -12,9 +12,9 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@ -101,8 +101,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
MeanSquaredLogarithmicError::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuber.NAME)),
PseudoHuber::fromXContent),
new ParseField(registeredMetricName(Regression.NAME, Huber.NAME)),
Huber::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
RSquared::fromXContent)
@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
PseudoHuber::new),
registeredMetricName(Regression.NAME, Huber.NAME),
Huber::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared::new),
@ -193,8 +193,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
PseudoHuber.Result::new),
registeredMetricName(Regression.NAME, Huber.NAME),
Huber.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared.Result::new)

View File

@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -37,13 +38,13 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
/**
* Calculates the pseudo Huber loss function.
*
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
* equation: huber = 1/n * Σ(δ^2 * (sqrt(1 + a^2 / δ^2) - 1))
* where: a = y - y´
* δ - parameter that controls the steepness
*/
public class PseudoHuber implements EvaluationMetric {
public class Huber implements EvaluationMetric {
public static final ParseField NAME = new ParseField("pseudo_huber");
public static final ParseField NAME = new ParseField(LossFunction.HUBER.toString());
public static final ParseField DELTA = new ParseField("delta");
private static final double DEFAULT_DELTA = 1.0;
@ -58,25 +59,25 @@ public class PseudoHuber implements EvaluationMetric {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
}
private static final ConstructingObjectParser<PseudoHuber, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new PseudoHuber((Double) args[0]));
private static final ConstructingObjectParser<Huber, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new Huber((Double) args[0]));
static {
PARSER.declareDouble(optionalConstructorArg(), DELTA);
}
public static PseudoHuber fromXContent(XContentParser parser) {
public static Huber fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final double delta;
private EvaluationMetricResult result;
public PseudoHuber(StreamInput in) throws IOException {
public Huber(StreamInput in) throws IOException {
this.delta = in.readDouble();
}
public PseudoHuber(@Nullable Double delta) {
public Huber(@Nullable Double delta) {
this.delta = delta != null ? delta : DEFAULT_DELTA;
}
@ -130,7 +131,7 @@ public class PseudoHuber implements EvaluationMetric {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PseudoHuber that = (PseudoHuber) o;
Huber that = (Huber) o;
return this.delta == that.delta;
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -40,7 +41,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
*/
public class MeanSquaredError implements EvaluationMetric {
public static final ParseField NAME = new ParseField("mean_squared_error");
public static final ParseField NAME = new ParseField(LossFunction.MSE.toString());
private static final String PAINLESS_TEMPLATE =
"def diff = doc[''{0}''].value - doc[''{1}''].value;" +

View File

@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
@ -42,7 +43,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
*/
public class MeanSquaredLogarithmicError implements EvaluationMetric {
public static final ParseField NAME = new ParseField("mean_squared_logarithmic_error");
public static final ParseField NAME = new ParseField(LossFunction.MSLE.toString());
public static final ParseField OFFSET = new ParseField("offset");
private static final double DEFAULT_OFFSET = 1.0;

View File

@ -17,7 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Preci
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import java.util.Arrays;
@ -41,7 +41,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
MulticlassConfusionMatrixResultTests.createRandom(),
new MeanSquaredError.Result(randomDouble()),
new MeanSquaredLogarithmicError.Result(randomDouble()),
new PseudoHuber.Result(randomDouble()),
new Huber.Result(randomDouble()),
new RSquared.Result(randomDouble()));
return new Response(evaluationName, randomSubsetOf(metrics));
}

View File

@ -19,37 +19,37 @@ import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo;
public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> {
public class HuberTests extends AbstractSerializingTestCase<Huber> {
@Override
protected PseudoHuber doParseInstance(XContentParser parser) throws IOException {
return PseudoHuber.fromXContent(parser);
protected Huber doParseInstance(XContentParser parser) throws IOException {
return Huber.fromXContent(parser);
}
@Override
protected PseudoHuber createTestInstance() {
protected Huber createTestInstance() {
return createRandom();
}
@Override
protected Writeable.Reader<PseudoHuber> instanceReader() {
return PseudoHuber::new;
protected Writeable.Reader<Huber> instanceReader() {
return Huber::new;
}
public static PseudoHuber createRandom() {
return new PseudoHuber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
public static Huber createRandom() {
return new Huber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
}
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("regression_pseudo_huber", 0.8123),
mockSingleValue("regression_huber", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
PseudoHuber pseudoHuber = new PseudoHuber((Double) null);
pseudoHuber.process(aggs);
Huber huber = new Huber((Double) null);
huber.process(aggs);
EvaluationMetricResult result = pseudoHuber.getResult().get();
EvaluationMetricResult result = huber.getResult().get();
String expected = "{\"value\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected));
}
@ -59,10 +59,10 @@ public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> {
mockSingleValue("some_other_single_metric_agg", 0.2377)
));
PseudoHuber pseudoHuber = new PseudoHuber((Double) null);
pseudoHuber.process(aggs);
Huber huber = new Huber((Double) null);
huber.process(aggs);
EvaluationMetricResult result = pseudoHuber.getResult().get();
assertThat(result, equalTo(new PseudoHuber.Result(0.0)));
EvaluationMetricResult result = huber.getResult().get();
assertThat(result, equalTo(new Huber.Result(0.0)));
}
}

View File

@ -42,7 +42,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("regression_mean_squared_error", 0.8123),
mockSingleValue("regression_mse", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));

View File

@ -42,7 +42,7 @@ public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCas
public void testEvaluate() {
Aggregations aggs = new Aggregations(Arrays.asList(
mockSingleValue("regression_mean_squared_logarithmic_error", 0.8123),
mockSingleValue("regression_msle", 0.8123),
mockSingleValue("some_other_single_metric_agg", 0.2377)
));

View File

@ -11,9 +11,9 @@ import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.junit.After;
@ -105,18 +105,18 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
}
public void testEvaluate_PseudoHuber() {
public void testEvaluate_Huber() {
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
evaluateDataFrame(
HOUSES_DATA_INDEX,
new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new PseudoHuber((Double) null))));
new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new Huber((Double) null))));
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
PseudoHuber.Result pseudoHuberResult = (PseudoHuber.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuber.NAME.getPreferredName()));
assertThat(pseudoHuberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6));
Huber.Result huberResult = (Huber.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(huberResult.getMetricName(), equalTo(Huber.NAME.getPreferredName()));
assertThat(huberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6));
}
public void testEvaluate_RSquared() {

View File

@ -841,15 +841,15 @@ setup:
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred",
"metrics": { "mean_squared_error": {} }
"metrics": { "mse": {} }
}
}
}
- match: { regression.mean_squared_error.value: 28.67749840974834 }
- is_false: regression.mean_squared_logarithmic_error.value
- match: { regression.mse.value: 28.67749840974834 }
- is_false: regression.msle.value
- is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
- is_false: regression.huber.value
---
"Test regression mean_squared_logarithmic_error":
- do:
@ -861,17 +861,17 @@ setup:
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred",
"metrics": { "mean_squared_logarithmic_error": { "offset": 6.0 } }
"metrics": { "msle": { "offset": 6.0 } }
}
}
}
- match: { regression.mean_squared_logarithmic_error.value: 0.08680568028334916 }
- is_false: regression.mean_squared_error.value
- match: { regression.msle.value: 0.08680568028334916 }
- is_false: regression.mse.value
- is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
- is_false: regression.huber.value
---
"Test regression pseudo_huber":
"Test regression huber":
- do:
ml.evaluate_data_frame:
body: >
@ -881,14 +881,14 @@ setup:
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred",
"metrics": { "pseudo_huber": { "delta": 2.0 } }
"metrics": { "huber": { "delta": 2.0 } }
}
}
}
- match: { regression.pseudo_huber.value: 3.5088110471730145 }
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.mean_squared_error.value
- match: { regression.huber.value: 3.5088110471730145 }
- is_false: regression.msle.value
- is_false: regression.mse.value
- is_false: regression.r_squared.value
---
"Test regression r_squared":
@ -906,9 +906,9 @@ setup:
}
}
- match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_error
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.pseudo_huber.value
- is_false: regression.mse
- is_false: regression.msle.value
- is_false: regression.huber.value
---
"Test regression with null metrics":
@ -925,9 +925,10 @@ setup:
}
}
- match: { regression.mean_squared_error.value: 28.67749840974834 }
- match: { regression.mse.value: 28.67749840974834 }
- match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.msle.value
- is_false: regression.huber.value
---
"Test regression given missing actual_field":
- do: