[7.x] Implement pseudo Huber loss (PseudoHuber) evaluation metric for regression analysis (#58734) (#58825)

This commit is contained in:
Przemysław Witek 2020-07-01 14:52:06 +02:00 committed by GitHub
parent 822b7421ce
commit 909649dd15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 606 additions and 12 deletions

View File

@ -23,6 +23,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -102,6 +103,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
MeanSquaredLogarithmicErrorMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
PseudoHuberMetric::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
@ -149,6 +154,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME)),
PseudoHuberMetric.Result::fromXContent),
new NamedXContentRegistry.Entry(
EvaluationMetric.Result.class,
new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),

View File

@ -0,0 +1,142 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * Client-side description of the pseudo Huber loss evaluation metric for regression.
 *
 * equation: pseudohuber = 1/n * Σ(δ^2 * (sqrt(1 + a^2 / δ^2) - 1))
 * where: a = y - y´
 *        δ - parameter that controls the steepness
 */
public class PseudoHuberMetric implements EvaluationMetric {

    /** Name under which this metric is registered and reported. */
    public static final String NAME = "pseudo_huber";

    /** Optional {@code delta} field of the metric definition. */
    public static final ParseField DELTA = new ParseField("delta");

    private static final ConstructingObjectParser<PseudoHuberMetric, Void> PARSER =
        new ConstructingObjectParser<>(NAME, true, args -> new PseudoHuberMetric((Double) args[0]));

    static {
        PARSER.declareDouble(optionalConstructorArg(), DELTA);
    }

    /**
     * Parses a metric definition from the given parser.
     *
     * @param parser source of the XContent to read
     * @return the parsed metric
     */
    public static PseudoHuberMetric fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    // Kept as a nullable boxed value so that an unset delta is not serialized at all.
    private final Double delta;

    /**
     * @param delta steepness parameter; {@code null} leaves the field out of the serialized form
     */
    public PseudoHuberMetric(@Nullable Double delta) {
        this.delta = delta;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Writes this metric as an object that contains {@code delta} only when it has been set.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (delta != null) {
            builder.field(DELTA.getPreferredName(), delta);
        }
        return builder.endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        PseudoHuberMetric other = (PseudoHuberMetric) o;
        return Objects.equals(delta, other.delta);
    }

    @Override
    public int hashCode() {
        return Objects.hash(delta);
    }

    /**
     * Result of the pseudo Huber loss metric: a single numeric value.
     */
    public static class Result implements EvaluationMetric.Result {

        /** Field holding the computed loss. */
        public static final ParseField VALUE = new ParseField("value");

        private static final ConstructingObjectParser<Result, Void> PARSER =
            new ConstructingObjectParser<>("pseudo_huber_result", true, args -> new Result((double) args[0]));

        static {
            PARSER.declareDouble(constructorArg(), VALUE);
        }

        /**
         * Parses a result from the given parser.
         *
         * @param parser source of the XContent to read
         * @return the parsed result
         */
        public static Result fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        private final double value;

        /**
         * @param value the computed loss
         */
        public Result(double value) {
            this.value = value;
        }

        @Override
        public String getMetricName() {
            return NAME;
        }

        /**
         * @return the computed loss
         */
        public double getValue() {
            return value;
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder.startObject()
                .field(VALUE.getPreferredName(), value)
                .endObject();
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o == null || o.getClass() != getClass()) {
                return false;
            }
            Result other = (Result) o;
            // Same semantics as comparing the boxed values: NaN equals NaN, 0.0 differs from -0.0.
            return Double.compare(other.value, value) == 0;
        }

        @Override
        public int hashCode() {
            return Double.hashCode(value);
        }
    }
}

View File

@ -143,6 +143,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -1886,12 +1887,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
new Regression(
actualRegression,
predictedRegression,
new MeanSquaredErrorMetric(), new MeanSquaredLogarithmicErrorMetric(1.0), new RSquaredMetric()));
new MeanSquaredErrorMetric(),
new MeanSquaredLogarithmicErrorMetric(1.0),
new PseudoHuberMetric(1.0),
new RSquaredMetric()));
EvaluateDataFrameResponse evaluateDataFrameResponse =
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(3));
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
@ -1902,6 +1906,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9));
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));
assertThat(pseudoHuberResult.getValue(), closeTo(0.029669771640929276, 1e-9));
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));

View File

@ -62,6 +62,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classific
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -702,7 +703,7 @@ public class RestHighLevelClientTests extends ESTestCase {
public void testProvidedNamedXContents() {
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
assertEquals(66, namedXContents.size());
assertEquals(68, namedXContents.size());
Map<Class<?>, Integer> categories = new HashMap<>();
List<String> names = new ArrayList<>();
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@ -749,7 +750,7 @@ public class RestHighLevelClientTests extends ESTestCase {
assertTrue(names.contains(TimeSyncConfig.NAME));
assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
assertThat(names,
hasItems(
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@ -764,8 +765,9 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
assertThat(names,
hasItems(
registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@ -780,6 +782,7 @@ public class RestHighLevelClientTests extends ESTestCase {
registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
registeredMetricName(Regression.NAME, PseudoHuberMetric.NAME),
registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));

View File

@ -162,6 +162,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.PseudoHuberMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@ -3572,7 +3573,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// Evaluation metrics // <4>
new MeanSquaredErrorMetric(), // <5>
new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
new RSquaredMetric()); // <7>
new PseudoHuberMetric(1.0), // <7>
new RSquaredMetric()); // <8>
// end::evaluate-data-frame-evaluation-regression
EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@ -3586,12 +3588,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4>
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <5>
double rSquared = rSquaredResult.getValue(); // <6>
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
double pseudoHuber = pseudoHuberResult.getValue(); // <6>
RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <7>
double rSquared = rSquaredResult.getValue(); // <8>
// end::evaluate-data-frame-results-regression
assertThat(meanSquaredError, closeTo(0.021, 1e-3));
assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
assertThat(pseudoHuber, closeTo(0.01, 1e-3));
assertThat(rSquared, closeTo(0.941, 1e-3));
}
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip tests for {@link PseudoHuberMetric.Result}.
 */
public class PseudoHuberMetricResultTests extends AbstractXContentTestCase<PseudoHuberMetric.Result> {

    /**
     * Creates a result wrapping a random double.
     *
     * @return a freshly generated result
     */
    public static PseudoHuberMetric.Result randomResult() {
        double value = randomDouble();
        return new PseudoHuberMetric.Result(value);
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected PseudoHuberMetric.Result createTestInstance() {
        return randomResult();
    }

    @Override
    protected PseudoHuberMetric.Result doParseInstance(XContentParser parser) throws IOException {
        return PseudoHuberMetric.Result.fromXContent(parser);
    }
}

View File

@ -0,0 +1,49 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip tests for {@link PseudoHuberMetric}.
 */
public class PseudoHuberMetricTests extends AbstractXContentTestCase<PseudoHuberMetric> {

    /**
     * Creates a metric whose delta is randomly either unset or a random double.
     */
    @Override
    protected PseudoHuberMetric createTestInstance() {
        Double delta = null;
        if (randomBoolean()) {
            delta = randomDouble();
        }
        return new PseudoHuberMetric(delta);
    }

    @Override
    protected PseudoHuberMetric doParseInstance(XContentParser parser) throws IOException {
        return PseudoHuberMetric.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
        return new NamedXContentRegistry(provider.getNamedXContentParsers());
    }
}

View File

@ -44,6 +44,9 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
if (randomBoolean()) {
metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
}
if (randomBoolean()) {
metrics.add(new PseudoHuberMetricTests().createTestInstance());
}
if (randomBoolean()) {
metrics.add(new RSquaredMetric());
}

View File

@ -69,7 +69,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-regression]
<4> The remaining parameters are the metrics to be calculated based on the two fields described above
<5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error]
<6> Mean squared logarithmic error
<7> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
<7> https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[Pseudo Huber loss]
<8> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
include::../execution.asciidoc[]
@ -126,5 +127,7 @@ include-tagged::{doc-tests-file}[{api}-results-regression]
<2> Fetching the actual mean squared error value
<3> Fetching mean squared logarithmic error metric by name
<4> Fetching the actual mean squared logarithmic error value
<5> Fetching R squared metric by name
<6> Fetching the actual R squared value
<5> Fetching pseudo Huber loss metric by name
<6> Fetching the actual pseudo Huber loss value
<7> Fetching R squared metric by name
<8> Fetching the actual R squared value

View File

@ -133,6 +133,10 @@ which outputs a prediction of values.
(Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
(`ground truth`) value.
`pseudo_huber`:::
(Optional, object) Pseudo Huber loss function.
For more information, read https://en.wikipedia.org/wiki/Huber_loss#Pseudo-Huber_loss_function[this wiki article].
`r_squared`:::
(Optional, object) Proportion of the variance in the dependent variable that is predictable from the independent variables.
For more information, read https://en.wikipedia.org/wiki/Coefficient_of_determination[this wiki article].

View File

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Class
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@ -99,6 +100,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
MeanSquaredLogarithmicError::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, PseudoHuber.NAME)),
PseudoHuber::fromXContent),
new NamedXContentRegistry.Entry(EvaluationMetric.class,
new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
RSquared::fromXContent)
@ -151,6 +155,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
PseudoHuber::new),
new NamedWriteableRegistry.Entry(EvaluationMetric.class,
registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared::new),
@ -185,6 +192,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
MeanSquaredLogarithmicError.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, PseudoHuber.NAME),
PseudoHuber.Result::new),
new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
registeredMetricName(Regression.NAME, RSquared.NAME),
RSquared.Result::new)

View File

@ -0,0 +1,195 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
/**
* Calculates the pseudo Huber loss function.
*
* equation: pseudohuber = 1/n * Σ(δ^2 * sqrt(1 + a^2 / δ^2) - 1)
* where: a = y - y´
* δ - parameter that controls the steepness
*/
public class PseudoHuber implements EvaluationMetric {
public static final ParseField NAME = new ParseField("pseudo_huber");
public static final ParseField DELTA = new ParseField("delta");
private static final double DEFAULT_DELTA = 1.0;
private static final String PAINLESS_TEMPLATE =
"def a = doc[''{0}''].value - doc[''{1}''].value;" +
"def delta2 = {2};" +
"return delta2 * (Math.sqrt(1.0 + Math.pow(a, 2) / delta2) - 1.0);";
private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
private static String buildScript(Object...args) {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
}
private static final ConstructingObjectParser<PseudoHuber, Void> PARSER =
new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new PseudoHuber((Double) args[0]));
static {
PARSER.declareDouble(optionalConstructorArg(), DELTA);
}
public static PseudoHuber fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final double delta;
private EvaluationMetricResult result;
public PseudoHuber(StreamInput in) throws IOException {
this.delta = in.readDouble();
}
public PseudoHuber(@Nullable Double delta) {
this.delta = delta != null ? delta : DEFAULT_DELTA;
}
@Override
public String getName() {
return NAME.getPreferredName();
}
@Override
public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
String actualField,
String predictedField) {
if (result != null) {
return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
}
return Tuple.tuple(
Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, delta * delta)))),
Collections.emptyList());
}
@Override
public void process(Aggregations aggs) {
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
result = value == null ? new Result(0.0) : new Result(value.value());
}
@Override
public Optional<EvaluationMetricResult> getResult() {
return Optional.ofNullable(result);
}
@Override
public String getWriteableName() {
return registeredMetricName(Regression.NAME, NAME);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(delta);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DELTA.getPreferredName(), delta);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PseudoHuber that = (PseudoHuber) o;
return this.delta == that.delta;
}
@Override
public int hashCode() {
return Double.hashCode(delta);
}
public static class Result implements EvaluationMetricResult {
private static final String VALUE = "value";
private final double value;
public Result(double value) {
this.value = value;
}
public Result(StreamInput in) throws IOException {
this.value = in.readDouble();
}
@Override
public String getWriteableName() {
return registeredMetricName(Regression.NAME, NAME);
}
@Override
public String getMetricName() {
return NAME.getPreferredName();
}
public double getValue() {
return value;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VALUE, value);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o;
return value == other.value;
}
@Override
public int hashCode() {
return Double.hashCode(value);
}
}
}

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Preci
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import java.util.Arrays;
@ -40,6 +41,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
MulticlassConfusionMatrixResultTests.createRandom(),
new MeanSquaredError.Result(randomDouble()),
new MeanSquaredLogarithmicError.Result(randomDouble()),
new PseudoHuber.Result(randomDouble()),
new RSquared.Result(randomDouble()));
return new Response(evaluationName, randomSubsetOf(metrics));
}

View File

@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
import static org.hamcrest.Matchers.equalTo;
/**
 * Serialization and evaluation tests for {@link PseudoHuber}.
 */
public class PseudoHuberTests extends AbstractSerializingTestCase<PseudoHuber> {

    /**
     * Creates a metric whose delta is randomly either unset or drawn from the range
     * 0.0 to 1000.0.
     */
    public static PseudoHuber createRandom() {
        Double delta = null;
        if (randomBoolean()) {
            delta = randomDoubleBetween(0.0, 1000.0, false);
        }
        return new PseudoHuber(delta);
    }

    @Override
    protected PseudoHuber createTestInstance() {
        return createRandom();
    }

    @Override
    protected PseudoHuber doParseInstance(XContentParser parser) throws IOException {
        return PseudoHuber.fromXContent(parser);
    }

    @Override
    protected Writeable.Reader<PseudoHuber> instanceReader() {
        return PseudoHuber::new;
    }

    /** The value of the metric's own aggregation is reported; unrelated aggregations are ignored. */
    public void testEvaluate() {
        Aggregations aggregations = new Aggregations(Arrays.asList(
            mockSingleValue("regression_pseudo_huber", 0.8123),
            mockSingleValue("some_other_single_metric_agg", 0.2377)
        ));
        PseudoHuber metric = new PseudoHuber((Double) null);

        metric.process(aggregations);

        EvaluationMetricResult actual = metric.getResult().get();
        assertThat(Strings.toString(actual), equalTo("{\"value\":0.8123}"));
    }

    /** Without the metric's own aggregation the result is 0.0. */
    public void testEvaluate_GivenMissingAggs() {
        Aggregations aggregations = new Aggregations(Collections.singletonList(
            mockSingleValue("some_other_single_metric_agg", 0.2377)
        ));
        PseudoHuber metric = new PseudoHuber((Double) null);

        metric.process(aggregations);

        EvaluationMetricResult actual = metric.getResult().get();
        assertThat(actual, equalTo(new PseudoHuber.Result(0.0)));
    }
}

View File

@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.PseudoHuber;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
import org.junit.After;
@ -101,7 +102,21 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName()));
assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1001), 2), 10E-6));
assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
}
/** Evaluates the houses index with the pseudo Huber metric alone, leaving delta unset. */
public void testEvaluate_PseudoHuber() {
    Regression evaluation =
        new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, Collections.singletonList(new PseudoHuber((Double) null)));
    EvaluateDataFrameAction.Response response = evaluateDataFrame(HOUSES_DATA_INDEX, evaluation);

    assertThat(response.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
    assertThat(response.getMetrics(), hasSize(1));

    PseudoHuber.Result pseudoHuberResult = (PseudoHuber.Result) response.getMetrics().get(0);
    assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuber.NAME.getPreferredName()));
    // NOTE(review): the expected value presumably corresponds to every prediction being off
    // by 1000 with the default delta - confirm against the test data setup.
    assertThat(pseudoHuberResult.getValue(), closeTo(Math.sqrt(1000000 + 1) - 1, 10E-6));
}
public void testEvaluate_RSquared() {

View File

@ -849,6 +849,7 @@ setup:
- match: { regression.mean_squared_error.error: 28.67749840974834 }
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
---
"Test regression mean_squared_logarithmic_error":
- do:
@ -868,6 +869,27 @@ setup:
- match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 }
- is_false: regression.mean_squared_error.value
- is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
---
"Test regression pseudo_huber":
- do:
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "regression_field_pred",
"metrics": { "pseudo_huber": { "delta": 2.0 } }
}
}
}
- match: { regression.pseudo_huber.value: 3.5088110471730145 }
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.mean_squared_error.value
- is_false: regression.r_squared.value
---
"Test regression r_squared":
- do:
@ -886,6 +908,8 @@ setup:
- match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_error
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.pseudo_huber.value
---
"Test regression with null metrics":
- do: