Rename "error" field to "value" for consistency between metrics (#58726) (#58870)

This commit is contained in:
Przemysław Witek 2020-07-02 09:08:56 +02:00 committed by GitHub
parent a5a8b4ae1d
commit 8e074c4495
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 80 additions and 85 deletions

View File

@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@ -49,7 +48,12 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
public MeanSquaredErrorMetric() {}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
@ -68,41 +72,36 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
return Objects.hashCode(NAME);
}
@Override
public String getName() {
return NAME;
}
public static class Result implements EvaluationMetric.Result {
public static final ParseField ERROR = new ParseField("error");
private final double error;
public static final ParseField VALUE = new ParseField("value");
private final double value;
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static {
PARSER.declareDouble(constructorArg(), ERROR);
PARSER.declareDouble(constructorArg(), VALUE);
}
public Result(double error) {
this.error = error;
public Result(double value) {
this.value = value;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ERROR.getPreferredName(), error);
builder.field(VALUE.getPreferredName(), value);
builder.endObject();
return builder;
}
public double getError() {
return error;
public double getValue() {
return value;
}
@Override
@ -115,12 +114,12 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(that.error, this.error);
return this.value == that.value;
}
@Override
public int hashCode() {
return Objects.hash(error);
return Double.hashCode(value);
}
}
}

View File

@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
@ -67,7 +66,7 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (offset != null) {
builder.field(OFFSET.getPreferredName(), offset);
@ -91,34 +90,34 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
public static class Result implements EvaluationMetric.Result {
public static final ParseField ERROR = new ParseField("error");
private final double error;
public static final ParseField VALUE = new ParseField("value");
private final double value;
public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static {
PARSER.declareDouble(constructorArg(), ERROR);
PARSER.declareDouble(constructorArg(), VALUE);
}
public Result(double error) {
this.error = error;
public Result(double value) {
this.value = value;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ERROR.getPreferredName(), error);
builder.field(VALUE.getPreferredName(), value);
builder.endObject();
return builder;
}
public double getError() {
return error;
public double getValue() {
return value;
}
@Override
@ -131,12 +130,12 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(that.error, this.error);
return this.value == that.value;
}
@Override
public int hashCode() {
return Objects.hash(error);
return Double.hashCode(value);
}
}
}

View File

@ -99,7 +99,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
}
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("pseudo_huber_result", true, args -> new Result((double) args[0]));
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static {
PARSER.declareDouble(constructorArg(), VALUE);
@ -131,7 +131,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(that.value, this.value);
return this.value == that.value;
}
@Override

View File

@ -50,6 +50,11 @@ public class RSquaredMetric implements EvaluationMetric {
public RSquaredMetric() {}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -70,11 +75,6 @@ public class RSquaredMetric implements EvaluationMetric {
return Objects.hashCode(NAME);
}
@Override
public String getName() {
return NAME;
}
public static class Result implements EvaluationMetric.Result {
public static final ParseField VALUE = new ParseField("value");
@ -85,7 +85,7 @@ public class RSquaredMetric implements EvaluationMetric {
}
private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("r_squared_result", true, args -> new Result((double) args[0]));
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static {
PARSER.declareDouble(constructorArg(), VALUE);
@ -117,12 +117,12 @@ public class RSquaredMetric implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o;
return Objects.equals(that.value, this.value);
return this.value == that.value;
}
@Override
public int hashCode() {
return Objects.hash(value);
return Double.hashCode(value);
}
}
}

View File

@ -1899,12 +1899,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
assertThat(mseResult.getValue(), closeTo(0.061000000, 1e-9));
MeanSquaredLogarithmicErrorMetric.Result msleResult =
evaluateDataFrameResponse.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME);
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9));
assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9));
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));

View File

@ -3582,11 +3582,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// tag::evaluate-data-frame-results-regression
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
double meanSquaredError = meanSquaredErrorResult.getError(); // <2>
double meanSquaredError = meanSquaredErrorResult.getValue(); // <2>
MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult =
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4>
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4>
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
double pseudoHuber = pseudoHuberResult.getValue(); // <6>

View File

@ -52,7 +52,7 @@ public class MeanSquaredError implements EvaluationMetric {
}
private static final ObjectParser<MeanSquaredError, Void> PARSER =
new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new);
new ObjectParser<>(NAME.getPreferredName(), true, MeanSquaredError::new);
public static MeanSquaredError fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
@ -99,7 +99,6 @@ public class MeanSquaredError implements EvaluationMetric {
@Override
public void writeTo(StreamOutput out) throws IOException {
}
@Override
@ -124,15 +123,15 @@ public class MeanSquaredError implements EvaluationMetric {
public static class Result implements EvaluationMetricResult {
private static final String ERROR = "error";
private final double error;
private static final String VALUE = "value";
private final double value;
public Result(double error) {
this.error = error;
public Result(double value) {
this.value = value;
}
public Result(StreamInput in) throws IOException {
this.error = in.readDouble();
this.value = in.readDouble();
}
@Override
@ -145,19 +144,19 @@ public class MeanSquaredError implements EvaluationMetric {
return NAME.getPreferredName();
}
public double getError() {
return error;
public double getValue() {
return value;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(error);
out.writeDouble(value);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ERROR, error);
builder.field(VALUE, value);
builder.endObject();
return builder;
}
@ -167,12 +166,12 @@ public class MeanSquaredError implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o;
return error == other.error;
return value == other.value;
}
@Override
public int hashCode() {
return Objects.hashCode(error);
return Double.hashCode(value);
}
}
}

View File

@ -29,7 +29,6 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -141,15 +140,15 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
public static class Result implements EvaluationMetricResult {
private static final String ERROR = "error";
private final double error;
private static final String VALUE = "value";
private final double value;
public Result(double error) {
this.error = error;
public Result(double value) {
this.value = value;
}
public Result(StreamInput in) throws IOException {
this.error = in.readDouble();
this.value = in.readDouble();
}
@Override
@ -162,19 +161,19 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
return NAME.getPreferredName();
}
public double getError() {
return error;
public double getValue() {
return value;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(error);
out.writeDouble(value);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ERROR, error);
builder.field(VALUE, value);
builder.endObject();
return builder;
}
@ -184,12 +183,12 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o;
return error == other.error;
return value == other.value;
}
@Override
public int hashCode() {
return Objects.hashCode(error);
return Double.hashCode(value);
}
}
}

View File

@ -52,12 +52,12 @@ public class RSquared implements EvaluationMetric {
"return diff * diff;";
private static final String SS_RES = "residual_sum_of_squares";
private static String buildScript(Object... args) {
private static String buildScript(Object...args) {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
}
private static final ObjectParser<RSquared, Void> PARSER =
new ObjectParser<>("r_squared", true, RSquared::new);
new ObjectParser<>(NAME.getPreferredName(), true, RSquared::new);
public static RSquared fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
@ -114,7 +114,6 @@ public class RSquared implements EvaluationMetric {
@Override
public void writeTo(StreamOutput out) throws IOException {
}
@Override
@ -187,7 +186,7 @@ public class RSquared implements EvaluationMetric {
@Override
public int hashCode() {
return Objects.hashCode(value);
return Double.hashCode(value);
}
}
}

View File

@ -50,7 +50,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
mse.process(aggs);
EvaluationMetricResult result = mse.getResult().get();
String expected = "{\"error\":0.8123}";
String expected = "{\"value\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected));
}

View File

@ -50,7 +50,7 @@ public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCas
msle.process(aggs);
EvaluationMetricResult result = msle.getResult().get();
String expected = "{\"error\":0.8123}";
String expected = "{\"value\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected));
}

View File

@ -85,7 +85,7 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
MeanSquaredError.Result mseResult = (MeanSquaredError.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME.getPreferredName()));
assertThat(mseResult.getError(), equalTo(1000000.0));
assertThat(mseResult.getValue(), equalTo(1000000.0));
}
public void testEvaluate_MeanSquaredLogarithmicError() {
@ -102,7 +102,7 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName()));
assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
}
public void testEvaluate_PseudoHuber() {

View File

@ -846,7 +846,7 @@ setup:
}
}
- match: { regression.mean_squared_error.error: 28.67749840974834 }
- match: { regression.mean_squared_error.value: 28.67749840974834 }
- is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
@ -866,7 +866,7 @@ setup:
}
}
- match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 }
- match: { regression.mean_squared_logarithmic_error.value: 0.08680568028334916 }
- is_false: regression.mean_squared_error.value
- is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value
@ -925,7 +925,7 @@ setup:
}
}
- match: { regression.mean_squared_error.error: 28.67749840974834 }
- match: { regression.mean_squared_error.value: 28.67749840974834 }
- match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_logarithmic_error.value
---