Rename "error" field to "value" for consistency between metrics (#58726) (#58870)

This commit is contained in:
Przemysław Witek 2020-07-02 09:08:56 +02:00 committed by GitHub
parent a5a8b4ae1d
commit 8e074c4495
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 80 additions and 85 deletions

View File

@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
@ -49,7 +48,12 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
public MeanSquaredErrorMetric() {} public MeanSquaredErrorMetric() {}
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.endObject(); builder.endObject();
return builder; return builder;
@ -68,41 +72,36 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
return Objects.hashCode(NAME); return Objects.hashCode(NAME);
} }
@Override
public String getName() {
return NAME;
}
public static class Result implements EvaluationMetric.Result { public static class Result implements EvaluationMetric.Result {
public static final ParseField ERROR = new ParseField("error"); public static final ParseField VALUE = new ParseField("value");
private final double error; private final double value;
public static Result fromXContent(XContentParser parser) { public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0])); new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static { static {
PARSER.declareDouble(constructorArg(), ERROR); PARSER.declareDouble(constructorArg(), VALUE);
} }
public Result(double error) { public Result(double value) {
this.error = error; this.value = value;
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ERROR.getPreferredName(), error); builder.field(VALUE.getPreferredName(), value);
builder.endObject(); builder.endObject();
return builder; return builder;
} }
public double getError() { public double getValue() {
return error; return value;
} }
@Override @Override
@ -115,12 +114,12 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(that.error, this.error); return this.value == that.value;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(error); return Double.hashCode(value);
} }
} }
} }

View File

@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
@ -67,7 +66,7 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
if (offset != null) { if (offset != null) {
builder.field(OFFSET.getPreferredName(), offset); builder.field(OFFSET.getPreferredName(), offset);
@ -89,36 +88,36 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
return Objects.hash(offset); return Objects.hash(offset);
} }
public static class Result implements EvaluationMetric.Result { public static class Result implements EvaluationMetric.Result {
public static final ParseField ERROR = new ParseField("error"); public static final ParseField VALUE = new ParseField("value");
private final double error; private final double value;
public static Result fromXContent(XContentParser parser) { public static Result fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
} }
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0])); new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static { static {
PARSER.declareDouble(constructorArg(), ERROR); PARSER.declareDouble(constructorArg(), VALUE);
} }
public Result(double error) { public Result(double value) {
this.error = error; this.value = value;
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ERROR.getPreferredName(), error); builder.field(VALUE.getPreferredName(), value);
builder.endObject(); builder.endObject();
return builder; return builder;
} }
public double getError() { public double getValue() {
return error; return value;
} }
@Override @Override
@ -131,12 +130,12 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(that.error, this.error); return this.value == that.value;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(error); return Double.hashCode(value);
} }
} }
} }

View File

@ -89,7 +89,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
return Objects.hash(delta); return Objects.hash(delta);
} }
public static class Result implements EvaluationMetric.Result { public static class Result implements EvaluationMetric.Result {
public static final ParseField VALUE = new ParseField("value"); public static final ParseField VALUE = new ParseField("value");
private final double value; private final double value;
@ -99,7 +99,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
} }
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("pseudo_huber_result", true, args -> new Result((double) args[0])); new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static { static {
PARSER.declareDouble(constructorArg(), VALUE); PARSER.declareDouble(constructorArg(), VALUE);
@ -131,7 +131,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(that.value, this.value); return this.value == that.value;
} }
@Override @Override

View File

@ -50,6 +50,11 @@ public class RSquaredMetric implements EvaluationMetric {
public RSquaredMetric() {} public RSquaredMetric() {}
@Override
public String getName() {
return NAME;
}
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
@ -70,11 +75,6 @@ public class RSquaredMetric implements EvaluationMetric {
return Objects.hashCode(NAME); return Objects.hashCode(NAME);
} }
@Override
public String getName() {
return NAME;
}
public static class Result implements EvaluationMetric.Result { public static class Result implements EvaluationMetric.Result {
public static final ParseField VALUE = new ParseField("value"); public static final ParseField VALUE = new ParseField("value");
@ -85,7 +85,7 @@ public class RSquaredMetric implements EvaluationMetric {
} }
private static final ConstructingObjectParser<Result, Void> PARSER = private static final ConstructingObjectParser<Result, Void> PARSER =
new ConstructingObjectParser<>("r_squared_result", true, args -> new Result((double) args[0])); new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
static { static {
PARSER.declareDouble(constructorArg(), VALUE); PARSER.declareDouble(constructorArg(), VALUE);
@ -117,12 +117,12 @@ public class RSquaredMetric implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result that = (Result) o; Result that = (Result) o;
return Objects.equals(that.value, this.value); return this.value == that.value;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(value); return Double.hashCode(value);
} }
} }
} }

View File

@ -1899,12 +1899,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME); MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME)); assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9)); assertThat(mseResult.getValue(), closeTo(0.061000000, 1e-9));
MeanSquaredLogarithmicErrorMetric.Result msleResult = MeanSquaredLogarithmicErrorMetric.Result msleResult =
evaluateDataFrameResponse.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); evaluateDataFrameResponse.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME);
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME)); assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9)); assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9));
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME); PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME)); assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));

View File

@ -3582,11 +3582,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// tag::evaluate-data-frame-results-regression // tag::evaluate-data-frame-results-regression
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1> MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
double meanSquaredError = meanSquaredErrorResult.getError(); // <2> double meanSquaredError = meanSquaredErrorResult.getValue(); // <2>
MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult = MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult =
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3> response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4> double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4>
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5> PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
double pseudoHuber = pseudoHuberResult.getValue(); // <6> double pseudoHuber = pseudoHuberResult.getValue(); // <6>

View File

@ -52,7 +52,7 @@ public class MeanSquaredError implements EvaluationMetric {
} }
private static final ObjectParser<MeanSquaredError, Void> PARSER = private static final ObjectParser<MeanSquaredError, Void> PARSER =
new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new); new ObjectParser<>(NAME.getPreferredName(), true, MeanSquaredError::new);
public static MeanSquaredError fromXContent(XContentParser parser) { public static MeanSquaredError fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
@ -99,7 +99,6 @@ public class MeanSquaredError implements EvaluationMetric {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
} }
@Override @Override
@ -124,15 +123,15 @@ public class MeanSquaredError implements EvaluationMetric {
public static class Result implements EvaluationMetricResult { public static class Result implements EvaluationMetricResult {
private static final String ERROR = "error"; private static final String VALUE = "value";
private final double error; private final double value;
public Result(double error) { public Result(double value) {
this.error = error; this.value = value;
} }
public Result(StreamInput in) throws IOException { public Result(StreamInput in) throws IOException {
this.error = in.readDouble(); this.value = in.readDouble();
} }
@Override @Override
@ -145,19 +144,19 @@ public class MeanSquaredError implements EvaluationMetric {
return NAME.getPreferredName(); return NAME.getPreferredName();
} }
public double getError() { public double getValue() {
return error; return value;
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(error); out.writeDouble(value);
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ERROR, error); builder.field(VALUE, value);
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -167,12 +166,12 @@ public class MeanSquaredError implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o; Result other = (Result)o;
return error == other.error; return value == other.value;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hashCode(error); return Double.hashCode(value);
} }
} }
} }

View File

@ -29,7 +29,6 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@ -141,15 +140,15 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
public static class Result implements EvaluationMetricResult { public static class Result implements EvaluationMetricResult {
private static final String ERROR = "error"; private static final String VALUE = "value";
private final double error; private final double value;
public Result(double error) { public Result(double value) {
this.error = error; this.value = value;
} }
public Result(StreamInput in) throws IOException { public Result(StreamInput in) throws IOException {
this.error = in.readDouble(); this.value = in.readDouble();
} }
@Override @Override
@ -162,19 +161,19 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
return NAME.getPreferredName(); return NAME.getPreferredName();
} }
public double getError() { public double getValue() {
return error; return value;
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(error); out.writeDouble(value);
} }
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ERROR, error); builder.field(VALUE, value);
builder.endObject(); builder.endObject();
return builder; return builder;
} }
@ -184,12 +183,12 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
if (this == o) return true; if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false; if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o; Result other = (Result)o;
return error == other.error; return value == other.value;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hashCode(error); return Double.hashCode(value);
} }
} }
} }

View File

@ -52,12 +52,12 @@ public class RSquared implements EvaluationMetric {
"return diff * diff;"; "return diff * diff;";
private static final String SS_RES = "residual_sum_of_squares"; private static final String SS_RES = "residual_sum_of_squares";
private static String buildScript(Object... args) { private static String buildScript(Object...args) {
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args); return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
} }
private static final ObjectParser<RSquared, Void> PARSER = private static final ObjectParser<RSquared, Void> PARSER =
new ObjectParser<>("r_squared", true, RSquared::new); new ObjectParser<>(NAME.getPreferredName(), true, RSquared::new);
public static RSquared fromXContent(XContentParser parser) { public static RSquared fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); return PARSER.apply(parser, null);
@ -114,7 +114,6 @@ public class RSquared implements EvaluationMetric {
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
} }
@Override @Override
@ -187,7 +186,7 @@ public class RSquared implements EvaluationMetric {
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hashCode(value); return Double.hashCode(value);
} }
} }
} }

View File

@ -50,7 +50,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
mse.process(aggs); mse.process(aggs);
EvaluationMetricResult result = mse.getResult().get(); EvaluationMetricResult result = mse.getResult().get();
String expected = "{\"error\":0.8123}"; String expected = "{\"value\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected)); assertThat(Strings.toString(result), equalTo(expected));
} }

View File

@ -50,7 +50,7 @@ public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCas
msle.process(aggs); msle.process(aggs);
EvaluationMetricResult result = msle.getResult().get(); EvaluationMetricResult result = msle.getResult().get();
String expected = "{\"error\":0.8123}"; String expected = "{\"value\":0.8123}";
assertThat(Strings.toString(result), equalTo(expected)); assertThat(Strings.toString(result), equalTo(expected));
} }

View File

@ -85,7 +85,7 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
MeanSquaredError.Result mseResult = (MeanSquaredError.Result) evaluateDataFrameResponse.getMetrics().get(0); MeanSquaredError.Result mseResult = (MeanSquaredError.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME.getPreferredName())); assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME.getPreferredName()));
assertThat(mseResult.getError(), equalTo(1000000.0)); assertThat(mseResult.getValue(), equalTo(1000000.0));
} }
public void testEvaluate_MeanSquaredLogarithmicError() { public void testEvaluate_MeanSquaredLogarithmicError() {
@ -102,7 +102,7 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0); MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0);
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName())); assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName()));
assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6)); assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
} }
public void testEvaluate_PseudoHuber() { public void testEvaluate_PseudoHuber() {

View File

@ -846,7 +846,7 @@ setup:
} }
} }
- match: { regression.mean_squared_error.error: 28.67749840974834 } - match: { regression.mean_squared_error.value: 28.67749840974834 }
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.mean_squared_logarithmic_error.value
- is_false: regression.r_squared.value - is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value - is_false: regression.pseudo_huber.value
@ -866,7 +866,7 @@ setup:
} }
} }
- match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 } - match: { regression.mean_squared_logarithmic_error.value: 0.08680568028334916 }
- is_false: regression.mean_squared_error.value - is_false: regression.mean_squared_error.value
- is_false: regression.r_squared.value - is_false: regression.r_squared.value
- is_false: regression.pseudo_huber.value - is_false: regression.pseudo_huber.value
@ -925,7 +925,7 @@ setup:
} }
} }
- match: { regression.mean_squared_error.error: 28.67749840974834 } - match: { regression.mean_squared_error.value: 28.67749840974834 }
- match: { regression.r_squared.value: 0.8551031778603486 } - match: { regression.r_squared.value: 0.8551031778603486 }
- is_false: regression.mean_squared_logarithmic_error.value - is_false: regression.mean_squared_logarithmic_error.value
--- ---