This commit is contained in:
parent
a5a8b4ae1d
commit
8e074c4495
|
@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
|||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
|
@ -49,7 +48,12 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
|
|||
public MeanSquaredErrorMetric() {}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
@ -68,41 +72,36 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
|
|||
return Objects.hashCode(NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static final ParseField ERROR = new ParseField("error");
|
||||
private final double error;
|
||||
public static final ParseField VALUE = new ParseField("value");
|
||||
private final double value;
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
|
||||
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), ERROR);
|
||||
PARSER.declareDouble(constructorArg(), VALUE);
|
||||
}
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
public Result(double value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR.getPreferredName(), error);
|
||||
builder.field(VALUE.getPreferredName(), value);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -115,12 +114,12 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(that.error, this.error);
|
||||
return this.value == that.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(error);
|
||||
return Double.hashCode(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
|||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
|
@ -67,7 +66,7 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
|
|||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (offset != null) {
|
||||
builder.field(OFFSET.getPreferredName(), offset);
|
||||
|
@ -91,34 +90,34 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
|
|||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static final ParseField ERROR = new ParseField("error");
|
||||
private final double error;
|
||||
public static final ParseField VALUE = new ParseField("value");
|
||||
private final double value;
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
|
||||
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), ERROR);
|
||||
PARSER.declareDouble(constructorArg(), VALUE);
|
||||
}
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
public Result(double value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR.getPreferredName(), error);
|
||||
builder.field(VALUE.getPreferredName(), value);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -131,12 +130,12 @@ public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(that.error, this.error);
|
||||
return this.value == that.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(error);
|
||||
return Double.hashCode(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -99,7 +99,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
|
|||
}
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("pseudo_huber_result", true, args -> new Result((double) args[0]));
|
||||
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), VALUE);
|
||||
|
@ -131,7 +131,7 @@ public class PseudoHuberMetric implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(that.value, this.value);
|
||||
return this.value == that.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -50,6 +50,11 @@ public class RSquaredMetric implements EvaluationMetric {
|
|||
|
||||
public RSquaredMetric() {}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
@ -70,11 +75,6 @@ public class RSquaredMetric implements EvaluationMetric {
|
|||
return Objects.hashCode(NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static final ParseField VALUE = new ParseField("value");
|
||||
|
@ -85,7 +85,7 @@ public class RSquaredMetric implements EvaluationMetric {
|
|||
}
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("r_squared_result", true, args -> new Result((double) args[0]));
|
||||
new ConstructingObjectParser<>(NAME + "_result", true, args -> new Result((double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), VALUE);
|
||||
|
@ -117,12 +117,12 @@ public class RSquaredMetric implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(that.value, this.value);
|
||||
return this.value == that.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(value);
|
||||
return Double.hashCode(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1899,12 +1899,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
|
||||
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
|
||||
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
|
||||
assertThat(mseResult.getValue(), closeTo(0.061000000, 1e-9));
|
||||
|
||||
MeanSquaredLogarithmicErrorMetric.Result msleResult =
|
||||
evaluateDataFrameResponse.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME);
|
||||
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
|
||||
assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9));
|
||||
assertThat(msleResult.getValue(), closeTo(0.02759231770210426, 1e-9));
|
||||
|
||||
PseudoHuberMetric.Result pseudoHuberResult = evaluateDataFrameResponse.getMetricByName(PseudoHuberMetric.NAME);
|
||||
assertThat(pseudoHuberResult.getMetricName(), equalTo(PseudoHuberMetric.NAME));
|
||||
|
|
|
@ -3582,11 +3582,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
|
||||
// tag::evaluate-data-frame-results-regression
|
||||
MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
|
||||
double meanSquaredError = meanSquaredErrorResult.getError(); // <2>
|
||||
double meanSquaredError = meanSquaredErrorResult.getValue(); // <2>
|
||||
|
||||
MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult =
|
||||
response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
|
||||
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4>
|
||||
double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getValue(); // <4>
|
||||
|
||||
PseudoHuberMetric.Result pseudoHuberResult = response.getMetricByName(PseudoHuberMetric.NAME); // <5>
|
||||
double pseudoHuber = pseudoHuberResult.getValue(); // <6>
|
||||
|
|
|
@ -52,7 +52,7 @@ public class MeanSquaredError implements EvaluationMetric {
|
|||
}
|
||||
|
||||
private static final ObjectParser<MeanSquaredError, Void> PARSER =
|
||||
new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new);
|
||||
new ObjectParser<>(NAME.getPreferredName(), true, MeanSquaredError::new);
|
||||
|
||||
public static MeanSquaredError fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
|
@ -99,7 +99,6 @@ public class MeanSquaredError implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -124,15 +123,15 @@ public class MeanSquaredError implements EvaluationMetric {
|
|||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final String ERROR = "error";
|
||||
private final double error;
|
||||
private static final String VALUE = "value";
|
||||
private final double value;
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
public Result(double value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.error = in.readDouble();
|
||||
this.value = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -145,19 +144,19 @@ public class MeanSquaredError implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(error);
|
||||
out.writeDouble(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR, error);
|
||||
builder.field(VALUE, value);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -167,12 +166,12 @@ public class MeanSquaredError implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result other = (Result)o;
|
||||
return error == other.error;
|
||||
return value == other.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(error);
|
||||
return Double.hashCode(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,7 +29,6 @@ import java.util.Arrays;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||
|
@ -141,15 +140,15 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
|
|||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final String ERROR = "error";
|
||||
private final double error;
|
||||
private static final String VALUE = "value";
|
||||
private final double value;
|
||||
|
||||
public Result(double error) {
|
||||
this.error = error;
|
||||
public Result(double value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.error = in.readDouble();
|
||||
this.value = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -162,19 +161,19 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
|
|||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
public double getError() {
|
||||
return error;
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(error);
|
||||
out.writeDouble(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(ERROR, error);
|
||||
builder.field(VALUE, value);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
@ -184,12 +183,12 @@ public class MeanSquaredLogarithmicError implements EvaluationMetric {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result other = (Result)o;
|
||||
return error == other.error;
|
||||
return value == other.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(error);
|
||||
return Double.hashCode(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,12 +52,12 @@ public class RSquared implements EvaluationMetric {
|
|||
"return diff * diff;";
|
||||
private static final String SS_RES = "residual_sum_of_squares";
|
||||
|
||||
private static String buildScript(Object... args) {
|
||||
private static String buildScript(Object...args) {
|
||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
||||
}
|
||||
|
||||
private static final ObjectParser<RSquared, Void> PARSER =
|
||||
new ObjectParser<>("r_squared", true, RSquared::new);
|
||||
new ObjectParser<>(NAME.getPreferredName(), true, RSquared::new);
|
||||
|
||||
public static RSquared fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
|
@ -114,7 +114,6 @@ public class RSquared implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -187,7 +186,7 @@ public class RSquared implements EvaluationMetric {
|
|||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(value);
|
||||
return Double.hashCode(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
|||
mse.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = mse.getResult().get();
|
||||
String expected = "{\"error\":0.8123}";
|
||||
String expected = "{\"value\":0.8123}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCas
|
|||
msle.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = msle.getResult().get();
|
||||
String expected = "{\"error\":0.8123}";
|
||||
String expected = "{\"value\":0.8123}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
|
|||
|
||||
MeanSquaredError.Result mseResult = (MeanSquaredError.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME.getPreferredName()));
|
||||
assertThat(mseResult.getError(), equalTo(1000000.0));
|
||||
assertThat(mseResult.getValue(), equalTo(1000000.0));
|
||||
}
|
||||
|
||||
public void testEvaluate_MeanSquaredLogarithmicError() {
|
||||
|
@ -102,7 +102,7 @@ public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestC
|
|||
|
||||
MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0);
|
||||
assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName()));
|
||||
assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
|
||||
assertThat(msleResult.getValue(), closeTo(Math.pow(Math.log(1000 + 1), 2), 10E-6));
|
||||
}
|
||||
|
||||
public void testEvaluate_PseudoHuber() {
|
||||
|
|
|
@ -846,7 +846,7 @@ setup:
|
|||
}
|
||||
}
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
- match: { regression.mean_squared_error.value: 28.67749840974834 }
|
||||
- is_false: regression.mean_squared_logarithmic_error.value
|
||||
- is_false: regression.r_squared.value
|
||||
- is_false: regression.pseudo_huber.value
|
||||
|
@ -866,7 +866,7 @@ setup:
|
|||
}
|
||||
}
|
||||
|
||||
- match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 }
|
||||
- match: { regression.mean_squared_logarithmic_error.value: 0.08680568028334916 }
|
||||
- is_false: regression.mean_squared_error.value
|
||||
- is_false: regression.r_squared.value
|
||||
- is_false: regression.pseudo_huber.value
|
||||
|
@ -925,7 +925,7 @@ setup:
|
|||
}
|
||||
}
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
- match: { regression.mean_squared_error.value: 28.67749840974834 }
|
||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||
- is_false: regression.mean_squared_logarithmic_error.value
|
||||
---
|
||||
|
|
Loading…
Reference in New Issue