Rename regression evaluation metrics to make the names consistent with loss functions (#58887) (#58927)
This commit is contained in:
parent
6aa669c8bb
commit
751e84e4c8
|
@ -23,7 +23,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||||
|
@ -105,8 +105,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||||
MeanSquaredLogarithmicErrorMetric::fromXContent),
|
MeanSquaredLogarithmicErrorMetric::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
|
new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)),
|
||||||
PseudoHuberMetric::fromXContent),
|
HuberMetric::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.class,
|
EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
||||||
|
@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||||
MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
|
MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
|
new ParseField(registeredMetricName(Regression.NAME, HuberMetric.NAME)),
|
||||||
PseudoHuberMetric.Result::fromXContent),
|
HuberMetric.Result::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(
|
new NamedXContentRegistry.Entry(
|
||||||
EvaluationMetric.Result.class,
|
EvaluationMetric.Result.class,
|
||||||
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
|
@ -34,30 +35,30 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
|
||||||
/**
|
/**
|
||||||
* Calculates the pseudo Huber loss function.
|
* Calculates the pseudo Huber loss function.
|
||||||
*
|
*
|
||||||
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
|
* equation: huber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
|
||||||
* where: a = y - y´
|
* where: a = y - y´
|
||||||
* δ - parameter that controls the steepness
|
* δ - parameter that controls the steepness
|
||||||
*/
|
*/
|
||||||
public class PseudoHuberMetric implements EvaluationMetric {
|
public class HuberMetric implements EvaluationMetric {
|
||||||
|
|
||||||
public static final String NAME = "pseudo_huber";
|
public static final String NAME = LossFunction.HUBER.toString();
|
||||||
|
|
||||||
public static final ParseField DELTA = new ParseField("delta");
|
public static final ParseField DELTA = new ParseField("delta");
|
||||||
|
|
||||||
private static final ConstructingObjectParser<PseudoHuberMetric, Void> PARSER =
|
private static final ConstructingObjectParser<HuberMetric, Void> PARSER =
|
||||||
new ConstructingObjectParser<>(NAME, true, args -> new PseudoHuberMetric((Double) args[0]));
|
new ConstructingObjectParser<>(NAME, true, args -> new HuberMetric((Double) args[0]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareDouble(optionalConstructorArg(), DELTA);
|
PARSER.declareDouble(optionalConstructorArg(), DELTA);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static PseudoHuberMetric fromXContent(XContentParser parser) {
|
public static HuberMetric fromXContent(XContentParser parser) {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final Double delta;
|
private final Double delta;
|
||||||
|
|
||||||
public PseudoHuberMetric(@Nullable Double delta) {
|
public HuberMetric(@Nullable Double delta) {
|
||||||
this.delta = delta;
|
this.delta = delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +81,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
PseudoHuberMetric that = (PseudoHuberMetric) o;
|
HuberMetric that = (HuberMetric) o;
|
||||||
return Objects.equals(this.delta, that.delta);
|
return Objects.equals(this.delta, that.delta);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
@ -37,7 +38,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
|
||||||
*/
|
*/
|
||||||
public class MeanSquaredErrorMetric implements EvaluationMetric {
|
public class MeanSquaredErrorMetric implements EvaluationMetric {
|
||||||
|
|
||||||
public static final String NAME = "mean_squared_error";
|
public static final String NAME = LossFunction.MSE.toString();
|
||||||
|
|
||||||
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new);
|
private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new);
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.dataframe.Regression.LossFunction;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
|
@ -39,7 +40,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
|
||||||
*/
|
*/
|
||||||
public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
|
public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
|
||||||
|
|
||||||
public static final String NAME = "mean_squared_logarithmic_error";
|
public static final String NAME = LossFunction.MSLE.toString();
|
||||||
|
|
||||||
public static final ParseField OFFSET = new ParseField("offset");
|
public static final ParseField OFFSET = new ParseField("offset");
|
||||||
|
|
||||||
|
|
|
@ -143,7 +143,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||||
|
@ -1889,7 +1889,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
predictedRegression,
|
predictedRegression,
|
||||||
new MeanSquaredErrorMetric(),
|
new MeanSquaredErrorMetric(),
|
||||||
new MeanSquaredLogarithmicErrorMetric(1.0),
|
new MeanSquaredLogarithmicErrorMetric(1.0),
|
||||||
new PseudoHuberMetric(1.0),
|
new HuberMetric(1.0),
|
||||||
new RSquaredMetric()));
|
new RSquaredMetric()));
|
||||||
|
|
||||||
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
||||||
|
@ -1906,9 +1906,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
|
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
|
||||||
assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9));
|
assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9));
|
||||||
|
|
||||||
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
|
HuberMetric.Result huberResult = evaluateDataFrameResponse.getMetricByName(HuberMetric.NAME);
|
||||||
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));
|
assertThat(huberResult.getMetricName(), equalTo(HuberMetric.NAME));
|
||||||
assertThat(pseudoHuberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
|
assertThat(huberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
|
||||||
|
|
||||||
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
|
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
|
||||||
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
|
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
|
||||||
|
|
|
@ -62,7 +62,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||||
|
@ -765,7 +765,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
|
registeredMetricName(Regression.NAME, HuberMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||||
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||||
assertThat(names,
|
assertThat(names,
|
||||||
|
@ -782,7 +782,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
||||||
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
|
registeredMetricName(Regression.NAME, HuberMetric.NAME),
|
||||||
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
|
||||||
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
|
||||||
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
|
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
|
||||||
|
|
|
@ -162,7 +162,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||||
|
@ -3573,7 +3573,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
// Evaluation metrics // <4>
|
// Evaluation metrics // <4>
|
||||||
new MeanSquaredErrorMetric(), // <5>
|
new MeanSquaredErrorMetric(), // <5>
|
||||||
new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
|
new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
|
||||||
new PseudoHuberMetric(1.0), // <7>
|
new HuberMetric(1.0), // <7>
|
||||||
new RSquaredMetric()); // <8>
|
new RSquaredMetric()); // <8>
|
||||||
// end::evaluate-data-frame-evaluation-regression
|
// end::evaluate-data-frame-evaluation-regression
|
||||||
|
|
||||||
|
@ -3588,8 +3588,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
|
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
|
||||||
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4>
|
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4>
|
||||||
|
|
||||||
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
|
HuberMetric.Result huberResult = response.getMetricByName(HuberMetric.NAME); // <5>
|
||||||
double pseudoHuber = pseudoHuberResult.getValue(); // <6>
|
double huber = huberResult.getValue(); // <6>
|
||||||
|
|
||||||
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7>
|
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7>
|
||||||
double rSquared = rSquaredResult.getValue(); // <8>
|
double rSquared = rSquaredResult.getValue(); // <8>
|
||||||
|
@ -3597,7 +3597,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
|
|
||||||
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
|
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
|
||||||
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
|
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
|
||||||
assertThat(pseudoHuber, closeTo(0.01, 1e-3));
|
assertThat(huber, closeTo(0.01, 1e-3));
|
||||||
assertThat(rSquared, closeTo(0.941, 1e-3));
|
assertThat(rSquared, closeTo(0.941, 1e-3));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,20 +25,20 @@ import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
public class PseudoHuberMetricResultTests extends AbstractXContentTestCase<PseudoHuberMetric.Result> {
|
public class HuberMetricResultTests extends AbstractXContentTestCase<HuberMetric.Result> {
|
||||||
|
|
||||||
public static PseudoHuberMetric.Result randomResult() {
|
public static HuberMetric.Result randomResult() {
|
||||||
return new PseudoHuberMetric.Result(randomDouble());
|
return new HuberMetric.Result(randomDouble());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected PseudoHuberMetric.Result createTestInstance() {
|
protected HuberMetric.Result createTestInstance() {
|
||||||
return randomResult();
|
return randomResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected PseudoHuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
protected HuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
||||||
return PseudoHuberMetric.Result.fromXContent(parser);
|
return HuberMetric.Result.fromXContent(parser);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
|
@ -25,7 +25,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuberMetric> {
|
public class HuberMetricTests extends AbstractXContentTestCase<HuberMetric> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected NamedXContentRegistry xContentRegistry() {
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
@ -33,13 +33,13 @@ public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuber
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected PseudoHuberMetric createTestInstance() {
|
protected HuberMetric createTestInstance() {
|
||||||
return new PseudoHuberMetric(randomBoolean() ? randomDouble() : null);
|
return new HuberMetric(randomBoolean() ? randomDouble() : null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected PseudoHuberMetric doParseInstance(XContentParser parser) throws IOException {
|
protected HuberMetric doParseInstance(XContentParser parser) throws IOException {
|
||||||
return PseudoHuberMetric.fromXContent(parser);
|
return HuberMetric.fromXContent(parser);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
|
@ -45,7 +45,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||||
metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
|
metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
|
||||||
}
|
}
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
metrics.add(new PseudoHuberMetricTests().createTestInstance());
|
metrics.add(new HuberMetricTests().createTestInstance());
|
||||||
}
|
}
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
metrics.add(new RSquaredMetric());
|
metrics.add(new RSquaredMetric());
|
||||||
|
|
|
@ -125,15 +125,15 @@ which outputs a prediction of values.
|
||||||
(Optional, object) Specifies the metrics that are used for the evaluation.
|
(Optional, object) Specifies the metrics that are used for the evaluation.
|
||||||
Available metrics:
|
Available metrics:
|
||||||
|
|
||||||
`mean_squared_error`:::
|
`mse`:::
|
||||||
(Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value.
|
(Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value.
|
||||||
For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article].
|
For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article].
|
||||||
|
|
||||||
`mean_squared_logarithmic_error`:::
|
`msle`:::
|
||||||
(Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
|
(Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
|
||||||
(`ground truth`) value.
|
(`ground truth`) value.
|
||||||
|
|
||||||
`pseudo_huber`:::
|
`huber`:::
|
||||||
(Optional, object) Pseudo Huber loss function.
|
(Optional, object) Pseudo Huber loss function.
|
||||||
For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article].
|
For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article].
|
||||||
|
|
||||||
|
@ -280,7 +280,7 @@ POST _ml/data_frame/_evaluate
|
||||||
"predicted_field": "ml.price_prediction", <4>
|
"predicted_field": "ml.price_prediction", <4>
|
||||||
"metrics": {
|
"metrics": {
|
||||||
"r_squared": {},
|
"r_squared": {},
|
||||||
"mean_squared_error": {}
|
"mse": {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -317,7 +317,7 @@ POST _ml/data_frame/_evaluate
|
||||||
"predicted_field": "ml.G3_prediction", <3>
|
"predicted_field": "ml.G3_prediction", <3>
|
||||||
"metrics": {
|
"metrics": {
|
||||||
"r_squared": {},
|
"r_squared": {},
|
||||||
"mean_squared_error": {}
|
"mse": {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -356,7 +356,7 @@ POST _ml/data_frame/_evaluate
|
||||||
"predicted_field": "ml.G3_prediction", <3>
|
"predicted_field": "ml.G3_prediction", <3>
|
||||||
"metrics": {
|
"metrics": {
|
||||||
"r_squared": {},
|
"r_squared": {},
|
||||||
"mean_squared_error": {}
|
"mse": {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,9 +12,9 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||||
|
@ -101,8 +101,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||||
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
|
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
|
||||||
MeanSquaredLogarithmicError::fromXContent),
|
MeanSquaredLogarithmicError::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(Regression.NAME, PseudoHuber.NAME)),
|
new ParseField(registeredMetricName(Regression.NAME, Huber.NAME)),
|
||||||
PseudoHuber::fromXContent),
|
Huber::fromXContent),
|
||||||
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
||||||
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
|
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
|
||||||
RSquared::fromXContent)
|
RSquared::fromXContent)
|
||||||
|
@ -156,8 +156,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
|
||||||
MeanSquaredLogarithmicError::new),
|
MeanSquaredLogarithmicError::new),
|
||||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||||
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
|
registeredMetricName(Regression.NAME, Huber.NAME),
|
||||||
PseudoHuber::new),
|
Huber::new),
|
||||||
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
||||||
registeredMetricName(Regression.NAME, RSquared.NAME),
|
registeredMetricName(Regression.NAME, RSquared.NAME),
|
||||||
RSquared::new),
|
RSquared::new),
|
||||||
|
@ -193,8 +193,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
||||||
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
|
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
|
||||||
MeanSquaredLogarithmicError.Result::new),
|
MeanSquaredLogarithmicError.Result::new),
|
||||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||||
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
|
registeredMetricName(Regression.NAME, Huber.NAME),
|
||||||
PseudoHuber.Result::new),
|
Huber.Result::new),
|
||||||
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||||
registeredMetricName(Regression.NAME, RSquared.NAME),
|
registeredMetricName(Regression.NAME, RSquared.NAME),
|
||||||
RSquared.Result::new)
|
RSquared.Result::new)
|
||||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
@ -37,13 +38,13 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
|
||||||
/**
|
/**
|
||||||
* Calculates the pseudo Huber loss function.
|
* Calculates the pseudo Huber loss function.
|
||||||
*
|
*
|
||||||
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
|
* equation: huber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
|
||||||
* where: a = y - y´
|
* where: a = y - y´
|
||||||
* δ - parameter that controls the steepness
|
* δ - parameter that controls the steepness
|
||||||
*/
|
*/
|
||||||
public class PseudoHuber implements EvaluationMetric {
|
public class Huber implements EvaluationMetric {
|
||||||
|
|
||||||
public static final ParseField NAME = new ParseField("pseudo_huber");
|
public static final ParseField NAME = new ParseField(LossFunction.HUBER.toString());
|
||||||
|
|
||||||
public static final ParseField DELTA = new ParseField("delta");
|
public static final ParseField DELTA = new ParseField("delta");
|
||||||
private static final double DEFAULT_DELTA = 1.0;
|
private static final double DEFAULT_DELTA = 1.0;
|
||||||
|
@ -58,25 +59,25 @@ public class PseudoHuber implements EvaluationMetric {
|
||||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final ConstructingObjectParser<PseudoHuber, Void> PARSER =
|
private static final ConstructingObjectParser<Huber, Void> PARSER =
|
||||||
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new PseudoHuber((Double) args[0]));
|
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new Huber((Double) args[0]));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareDouble(optionalConstructorArg(), DELTA);
|
PARSER.declareDouble(optionalConstructorArg(), DELTA);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static PseudoHuber fromXContent(XContentParser parser) {
|
public static Huber fromXContent(XContentParser parser) {
|
||||||
return PARSER.apply(parser, null);
|
return PARSER.apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final double delta;
|
private final double delta;
|
||||||
private EvaluationMetricResult result;
|
private EvaluationMetricResult result;
|
||||||
|
|
||||||
public PseudoHuber(StreamInput in) throws IOException {
|
public Huber(StreamInput in) throws IOException {
|
||||||
this.delta = in.readDouble();
|
this.delta = in.readDouble();
|
||||||
}
|
}
|
||||||
|
|
||||||
public PseudoHuber(@Nullable Double delta) {
|
public Huber(@Nullable Double delta) {
|
||||||
this.delta = delta != null ? delta : DEFAULT_DELTA;
|
this.delta = delta != null ? delta : DEFAULT_DELTA;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +131,7 @@ public class PseudoHuber implements EvaluationMetric {
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
PseudoHuber that = (PseudoHuber) o;
|
Huber that = (Huber) o;
|
||||||
return this.delta == that.delta;
|
return this.delta == that.delta;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
@ -40,7 +41,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
|
||||||
*/
|
*/
|
||||||
public class MeanSquaredError implements EvaluationMetric {
|
public class MeanSquaredError implements EvaluationMetric {
|
||||||
|
|
||||||
public static final ParseField NAME = new ParseField("mean_squared_error");
|
public static final ParseField NAME = new ParseField(LossFunction.MSE.toString());
|
||||||
|
|
||||||
private static final String PAINLESS_TEMPLATE =
|
private static final String PAINLESS_TEMPLATE =
|
||||||
"def diff = doc[''{0}''].value - doc[''{1}''].value;" +
|
"def diff = doc[''{0}''].value - doc[''{1}''].value;" +
|
||||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||||
import org.elasticsearch.search.aggregations.Aggregations;
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
||||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression.LossFunction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
|
||||||
|
@ -42,7 +43,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationN
|
||||||
*/
|
*/
|
||||||
public class MeanSquaredLogarithmicError implements EvaluationMetric {
|
public class MeanSquaredLogarithmicError implements EvaluationMetric {
|
||||||
|
|
||||||
public static final ParseField NAME = new ParseField("mean_squared_logarithmic_error");
|
public static final ParseField NAME = new ParseField(LossFunction.MSLE.toString());
|
||||||
|
|
||||||
public static final ParseField OFFSET = new ParseField("offset");
|
public static final ParseField OFFSET = new ParseField("offset");
|
||||||
private static final double DEFAULT_OFFSET = 1.0;
|
private static final double DEFAULT_OFFSET = 1.0;
|
||||||
|
|
|
@ -17,7 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Preci
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -41,7 +41,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
|
||||||
MulticlassConfusionMatrixResultTests.createRandom(),
|
MulticlassConfusionMatrixResultTests.createRandom(),
|
||||||
new MeanSquaredError.Result(randomDouble()),
|
new MeanSquaredError.Result(randomDouble()),
|
||||||
new MeanSquaredLogarithmicError.Result(randomDouble()),
|
new MeanSquaredLogarithmicError.Result(randomDouble()),
|
||||||
new PseudoHuber.Result(randomDouble()),
|
new Huber.Result(randomDouble()),
|
||||||
new RSquared.Result(randomDouble()));
|
new RSquared.Result(randomDouble()));
|
||||||
return new Response(evaluationName, randomSubsetOf(metrics));
|
return new Response(evaluationName, randomSubsetOf(metrics));
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,37 +19,37 @@ import java.util.Collections;
|
||||||
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> {
|
public class HuberTests extends AbstractSerializingTestCase<Huber> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected PseudoHuber doParseInstance(XContentParser parser) throws IOException {
|
protected Huber doParseInstance(XContentParser parser) throws IOException {
|
||||||
return PseudoHuber.fromXContent(parser);
|
return Huber.fromXContent(parser);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected PseudoHuber createTestInstance() {
|
protected Huber createTestInstance() {
|
||||||
return createRandom();
|
return createRandom();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Writeable.Reader<PseudoHuber> instanceReader() {
|
protected Writeable.Reader<Huber> instanceReader() {
|
||||||
return PseudoHuber::new;
|
return Huber::new;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static PseudoHuber createRandom() {
|
public static Huber createRandom() {
|
||||||
return new PseudoHuber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
|
return new Huber(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockSingleValue("regression_pseudo_huber", 0.8123),
|
mockSingleValue("regression_huber", 0.8123),
|
||||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
PseudoHuber pseudoHuber = new PseudoHuber((Double) null);
|
Huber huber = new Huber((Double) null);
|
||||||
pseudoHuber.process(aggs);
|
huber.process(aggs);
|
||||||
|
|
||||||
EvaluationMetricResult result = pseudoHuber.getResult().get();
|
EvaluationMetricResult result = huber.getResult().get();
|
||||||
String expected = "{\"value\":0.8123}";
|
String expected = "{\"value\":0.8123}";
|
||||||
assertThat(Strings.toString(result), equalTo(expected));
|
assertThat(Strings.toString(result), equalTo(expected));
|
||||||
}
|
}
|
||||||
|
@ -59,10 +59,10 @@ public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> {
|
||||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
PseudoHuber pseudoHuber = new PseudoHuber((Double) null);
|
Huber huber = new Huber((Double) null);
|
||||||
pseudoHuber.process(aggs);
|
huber.process(aggs);
|
||||||
|
|
||||||
EvaluationMetricResult result = pseudoHuber.getResult().get();
|
EvaluationMetricResult result = huber.getResult().get();
|
||||||
assertThat(result, equalTo(new PseudoHuber.Result(0.0)));
|
assertThat(result, equalTo(new Huber.Result(0.0)));
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -42,7 +42,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockSingleValue("regression_mean_squared_error", 0.8123),
|
mockSingleValue("regression_mse", 0.8123),
|
||||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,7 @@ public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCas
|
||||||
|
|
||||||
public void testEvaluate() {
|
public void testEvaluate() {
|
||||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
mockSingleValue("regression_mean_squared_logarithmic_error", 0.8123),
|
mockSingleValue("regression_msle", 0.8123),
|
||||||
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
mockSingleValue("some_other_single_metric_agg", 0.2377)
|
||||||
));
|
));
|
||||||
|
|
||||||
|
|
|
@ -11,9 +11,9 @@ import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Huber;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -105,18 +105,18 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
|
||||||
assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
|
assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate_PseudoHuber() {
|
public void testEvaluate_Huber() {
|
||||||
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
EvaluateDataFrameAction.Response evaluateDataFrameResponse =
|
||||||
evaluateDataFrame(
|
evaluateDataFrame(
|
||||||
HOUSES_DATA_INDEX,
|
HOUSES_DATA_INDEX,
|
||||||
new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new PseudoHuber((Double) null))));
|
new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new Huber((Double) null))));
|
||||||
|
|
||||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
|
||||||
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
|
||||||
|
|
||||||
PseudoHuber.Result pseudoHuberResult = (PseudoHuber.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
Huber.Result huberResult = (Huber.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||||
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuber.NAME.getPreferredName()));
|
assertThat(huberResult.getMetricName(), equalTo(Huber.NAME.getPreferredName()));
|
||||||
assertThat(pseudoHuberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6));
|
assertThat(huberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testEvaluate_RSquared() {
|
public void testEvaluate_RSquared() {
|
||||||
|
|
|
@ -841,15 +841,15 @@ setup:
|
||||||
"regression": {
|
"regression": {
|
||||||
"actual_field": "regression_field_act",
|
"actual_field": "regression_field_act",
|
||||||
"predicted_field": "regression_field_pred",
|
"predicted_field": "regression_field_pred",
|
||||||
"metrics": { "mean_squared_error": {} }
|
"metrics": { "mse": {} }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
- match: { regression.mean_squared_error.value: 28.67749840974834 }
|
- match: { regression.mse.value: 28.67749840974834 }
|
||||||
- is_false: regression.mean_squared_logarithmic_error.value
|
- is_false: regression.msle.value
|
||||||
- is_false: regression.r_squared.value
|
- is_false: regression.r_squared.value
|
||||||
- is_false: regression.pseudo_huber.value
|
- is_false: regression.huber.value
|
||||||
---
|
---
|
||||||
"Test regression mean_squared_logarithmic_error":
|
"Test regression mean_squared_logarithmic_error":
|
||||||
- do:
|
- do:
|
||||||
|
@ -861,17 +861,17 @@ setup:
|
||||||
"regression": {
|
"regression": {
|
||||||
"actual_field": "regression_field_act",
|
"actual_field": "regression_field_act",
|
||||||
"predicted_field": "regression_field_pred",
|
"predicted_field": "regression_field_pred",
|
||||||
"metrics": { "mean_squared_logarithmic_error": { "offset": 6.0 } }
|
"metrics": { "msle": { "offset": 6.0 } }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
- match: { regression.mean_squared_logarithmic_error.value: 0.08680568028334916 }
|
- match: { regression.msle.value: 0.08680568028334916 }
|
||||||
- is_false: regression.mean_squared_error.value
|
- is_false: regression.mse.value
|
||||||
- is_false: regression.r_squared.value
|
- is_false: regression.r_squared.value
|
||||||
- is_false: regression.pseudo_huber.value
|
- is_false: regression.huber.value
|
||||||
---
|
---
|
||||||
"Test regression pseudo_huber":
|
"Test regression huber":
|
||||||
- do:
|
- do:
|
||||||
ml.evaluate_data_frame:
|
ml.evaluate_data_frame:
|
||||||
body: >
|
body: >
|
||||||
|
@ -881,14 +881,14 @@ setup:
|
||||||
"regression": {
|
"regression": {
|
||||||
"actual_field": "regression_field_act",
|
"actual_field": "regression_field_act",
|
||||||
"predicted_field": "regression_field_pred",
|
"predicted_field": "regression_field_pred",
|
||||||
"metrics": { "pseudo_huber": { "delta": 2.0 } }
|
"metrics": { "huber": { "delta": 2.0 } }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
- match: { regression.pseudo_huber.value: 3.5088110471730145 }
|
- match: { regression.huber.value: 3.5088110471730145 }
|
||||||
- is_false: regression.mean_squared_logarithmic_error.value
|
- is_false: regression.msle.value
|
||||||
- is_false: regression.mean_squared_error.value
|
- is_false: regression.mse.value
|
||||||
- is_false: regression.r_squared.value
|
- is_false: regression.r_squared.value
|
||||||
---
|
---
|
||||||
"Test regression r_squared":
|
"Test regression r_squared":
|
||||||
|
@ -906,9 +906,9 @@ setup:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||||
- is_false: regression.mean_squared_error
|
- is_false: regression.mse
|
||||||
- is_false: regression.mean_squared_logarithmic_error.value
|
- is_false: regression.msle.value
|
||||||
- is_false: regression.pseudo_huber.value
|
- is_false: regression.huber.value
|
||||||
|
|
||||||
---
|
---
|
||||||
"Test regression with null metrics":
|
"Test regression with null metrics":
|
||||||
|
@ -925,9 +925,10 @@ setup:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
- match: { regression.mean_squared_error.value: 28.67749840974834 }
|
- match: { regression.mse.value: 28.67749840974834 }
|
||||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||||
- is_false: regression.mean_squared_logarithmic_error.value
|
- is_false: regression.msle.value
|
||||||
|
- is_false: regression.huber.value
|
||||||
---
|
---
|
||||||
"Test regression given missing actual_field":
|
"Test regression given missing actual_field":
|
||||||
- do:
|
- do:
|
||||||
|
|
Loading…
Reference in New Issue