* [ML] Add r_squared eval metric to regression * fixing tests and binarysoftclassification class * Update RSquared.java * Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java Co-Authored-By: David Kyle <david.kyle@elastic.co> * removing unnecessary debug test
This commit is contained in:
parent
858dbfc074
commit
2c7ff812da
|
@ -19,6 +19,7 @@
|
|||
package org.elasticsearch.client.ml.dataframe.evaluation;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
|
@ -49,6 +50,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent),
|
||||
// Evaluation metrics results
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
|
||||
|
@ -56,6 +59,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
|
||||
new NamedXContentRegistry.Entry(
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
||||
/**
|
||||
* Calculates R-Squared between two known numerical fields.
|
||||
*
|
||||
* equation: mse = 1 - SSres/SStot
|
||||
* such that,
|
||||
* SSres = Σ(y - y´)^2
|
||||
* SStot = Σ(y - y_mean)^2
|
||||
*/
|
||||
public class RSquaredMetric implements EvaluationMetric {
|
||||
|
||||
public static final String NAME = "r_squared";
|
||||
|
||||
private static final ObjectParser<RSquaredMetric, Void> PARSER =
|
||||
new ObjectParser<>("r_squared", true, RSquaredMetric::new);
|
||||
|
||||
public static RSquaredMetric fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public RSquaredMetric() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
// create static hash code from name as there are currently no unique fields per class instance
|
||||
return Objects.hashCode(NAME);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetric.Result {
|
||||
|
||||
public static final ParseField VALUE = new ParseField("value");
|
||||
private final double value;
|
||||
|
||||
public static Result fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
private static final ConstructingObjectParser<Result, Void> PARSER =
|
||||
new ConstructingObjectParser<>("r_squared_result", true, args -> new Result((double) args[0]));
|
||||
|
||||
static {
|
||||
PARSER.declareDouble(constructorArg(), VALUE);
|
||||
}
|
||||
|
||||
public Result(double value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VALUE.getPreferredName(), value);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
public double getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result that = (Result) o;
|
||||
return Objects.equals(that.value, this.value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(value);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -84,8 +85,11 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = actualField;
|
||||
this.predictedField = predictedField;
|
||||
this.actualField = Objects.requireNonNull(actualField);
|
||||
this.predictedField = Objects.requireNonNull(predictedField);
|
||||
if (metrics != null) {
|
||||
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
|
||||
}
|
||||
this.metrics = metrics;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -52,6 +53,7 @@ public class BinarySoftClassification implements Evaluation {
|
|||
public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
NAME,
|
||||
true,
|
||||
args -> new BinarySoftClassification((String) args[0], (String) args[1], (List<EvaluationMetric>) args[2]));
|
||||
|
||||
static {
|
||||
|
@ -80,6 +82,10 @@ public class BinarySoftClassification implements Evaluation {
|
|||
*/
|
||||
private final List<EvaluationMetric> metrics;
|
||||
|
||||
public BinarySoftClassification(String actualField, String predictedField) {
|
||||
this(actualField, predictedField, (List<EvaluationMetric>)null);
|
||||
}
|
||||
|
||||
public BinarySoftClassification(String actualField, String predictedProbabilityField, EvaluationMetric... metric) {
|
||||
this(actualField, predictedProbabilityField, Arrays.asList(metric));
|
||||
}
|
||||
|
@ -88,7 +94,10 @@ public class BinarySoftClassification implements Evaluation {
|
|||
@Nullable List<EvaluationMetric> metrics) {
|
||||
this.actualField = Objects.requireNonNull(actualField);
|
||||
this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField);
|
||||
this.metrics = Objects.requireNonNull(metrics);
|
||||
if (metrics != null) {
|
||||
metrics.sort(Comparator.comparing(EvaluationMetric::getName));
|
||||
}
|
||||
this.metrics = metrics;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -102,11 +111,13 @@ public class BinarySoftClassification implements Evaluation {
|
|||
builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
|
||||
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
|
||||
|
||||
if (metrics != null) {
|
||||
builder.startObject(METRICS.getPreferredName());
|
||||
for (EvaluationMetric metric : metrics) {
|
||||
builder.field(metric.getName(), metric);
|
||||
}
|
||||
builder.endObject();
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
|
|
|
@ -124,6 +124,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
|||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
|
@ -1597,16 +1598,21 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
|
||||
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
||||
|
||||
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression));
|
||||
evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex,
|
||||
new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
||||
|
||||
evaluateDataFrameResponse =
|
||||
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
||||
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
||||
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
|
||||
|
||||
MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
|
||||
assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
|
||||
assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
|
||||
|
||||
RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
|
||||
assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
|
||||
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
|
||||
}
|
||||
|
||||
private static XContentBuilder defaultMappingForTest() throws IOException {
|
||||
|
|
|
@ -61,6 +61,7 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction;
|
|||
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
|
||||
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
||||
|
@ -676,7 +677,7 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
|
||||
public void testProvidedNamedXContents() {
|
||||
List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
|
||||
assertEquals(34, namedXContents.size());
|
||||
assertEquals(36, namedXContents.size());
|
||||
Map<Class<?>, Integer> categories = new HashMap<>();
|
||||
List<String> names = new ArrayList<>();
|
||||
for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
|
||||
|
@ -716,12 +717,22 @@ public class RestHighLevelClientTests extends ESTestCase {
|
|||
assertTrue(names.contains(TimeSyncConfig.NAME));
|
||||
assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
|
||||
assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
|
||||
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
|
||||
assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
|
||||
assertThat(names,
|
||||
hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
|
||||
hasItems(AucRocMetric.NAME,
|
||||
PrecisionMetric.NAME,
|
||||
RecallMetric.NAME,
|
||||
ConfusionMatrixMetric.NAME,
|
||||
MeanSquaredErrorMetric.NAME,
|
||||
RSquaredMetric.NAME));
|
||||
}
|
||||
|
||||
public void testApiNamingConventions() throws Exception {
|
||||
|
|
|
@ -26,7 +26,7 @@ import java.io.IOException;
|
|||
|
||||
public class ConfusionMatrixMetricConfusionMatrixTests extends AbstractXContentTestCase<ConfusionMatrixMetric.ConfusionMatrix> {
|
||||
|
||||
static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() {
|
||||
public static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() {
|
||||
return new ConfusionMatrixMetric.ConfusionMatrix(randomInt(), randomInt(), randomInt(), randomInt());
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RSquaredMetricResultTests extends AbstractXContentTestCase<RSquaredMetric.Result> {
|
||||
|
||||
public static RSquaredMetric.Result randomResult() {
|
||||
return new RSquaredMetric.Result(randomDouble());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RSquaredMetric.Result createTestInstance() {
|
||||
return randomResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RSquaredMetric.Result doParseInstance(XContentParser parser) throws IOException {
|
||||
return RSquaredMetric.Result.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RSquaredMetricTests extends AbstractXContentTestCase<RSquaredMetric> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RSquaredMetric createTestInstance() {
|
||||
return new RSquaredMetric();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RSquaredMetric doParseInstance(XContentParser parser) throws IOException {
|
||||
return RSquaredMetric.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -18,13 +18,15 @@
|
|||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
||||
|
@ -36,9 +38,16 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||
|
||||
@Override
|
||||
protected Regression createTestInstance() {
|
||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new MeanSquaredErrorMetric());
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new RSquaredMetric());
|
||||
}
|
||||
return randomBoolean() ?
|
||||
new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10)) :
|
||||
new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredErrorMetric()));
|
||||
new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -56,4 +65,5 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
|
|||
// allow unknown fields in the root of the object only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
|
||||
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
|
||||
import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
public class BinarySoftClassificationTests extends AbstractXContentTestCase<BinarySoftClassification> {
|
||||
|
||||
@Override
|
||||
protected NamedXContentRegistry xContentRegistry() {
|
||||
return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BinarySoftClassification createTestInstance() {
|
||||
List<EvaluationMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new AucRocMetric(randomBoolean()));
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1,
|
||||
4,
|
||||
Double[]::new,
|
||||
BinarySoftClassificationTests::randomDouble))));
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new RecallMetric(Arrays.asList(randomArray(1,
|
||||
4,
|
||||
Double[]::new,
|
||||
BinarySoftClassificationTests::randomDouble))));
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(new ConfusionMatrixMetric(Arrays.asList(randomArray(1,
|
||||
4,
|
||||
Double[]::new,
|
||||
BinarySoftClassificationTests::randomDouble))));
|
||||
}
|
||||
return randomBoolean() ?
|
||||
new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10)) :
|
||||
new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected BinarySoftClassification doParseInstance(XContentParser parser) throws IOException {
|
||||
return BinarySoftClassification.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean supportsUnknownFields() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||
// allow unknown fields in the root of the object only
|
||||
return field -> !field.isEmpty();
|
||||
}
|
||||
|
||||
}
|
|
@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
||||
|
@ -42,6 +43,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
|
||||
// Regression metrics
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
|
||||
namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
|
||||
|
||||
return namedXContent;
|
||||
}
|
||||
|
@ -66,6 +68,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
||||
RSquared.NAME.getPreferredName(),
|
||||
RSquared::new));
|
||||
|
||||
// Evaluation Metrics Results
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
|
||||
|
@ -77,6 +82,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
|
|||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
MeanSquaredError.NAME.getPreferredName(),
|
||||
MeanSquaredError.Result::new));
|
||||
namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
||||
RSquared.NAME.getPreferredName(),
|
||||
RSquared.Result::new));
|
||||
|
||||
return namedWriteables;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.ParseField;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
|
||||
import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.text.MessageFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Calculates R-Squared between two known numerical fields.
|
||||
*
|
||||
* equation: R-Squared = 1 - SSres/SStot
|
||||
* such that,
|
||||
* SSres = Σ(y - y´)^2, The residual sum of squares
|
||||
* SStot = Σ(y - y_mean)^2, The total sum of squares
|
||||
*/
|
||||
public class RSquared implements RegressionMetric {
|
||||
|
||||
public static final ParseField NAME = new ParseField("r_squared");
|
||||
|
||||
private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
|
||||
private static final String SS_RES = "residual_sum_of_squares";
|
||||
|
||||
private static String buildScript(Object... args) {
|
||||
return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
||||
}
|
||||
|
||||
private static final ObjectParser<RSquared, Void> PARSER =
|
||||
new ObjectParser<>("r_squared", true, RSquared::new);
|
||||
|
||||
public static RSquared fromXContent(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
|
||||
public RSquared(StreamInput in) {
|
||||
|
||||
}
|
||||
|
||||
public RSquared() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetricName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AggregationBuilder> aggs(String actualField, String predictedField) {
|
||||
return Arrays.asList(
|
||||
AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
|
||||
AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
|
||||
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
|
||||
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
|
||||
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
|
||||
null :
|
||||
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
// create static hash code from name as there are currently no unique fields per class instance
|
||||
return Objects.hashCode(NAME.getPreferredName());
|
||||
}
|
||||
|
||||
public static class Result implements EvaluationMetricResult {
|
||||
|
||||
private static final String VALUE = "value";
|
||||
private final double value;
|
||||
|
||||
public Result(double value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
public Result(StreamInput in) throws IOException {
|
||||
this.value = in.readDouble();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return NAME.getPreferredName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeDouble(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(VALUE, value);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -94,8 +94,9 @@ public class Regression implements Evaluation {
|
|||
}
|
||||
|
||||
private static List<RegressionMetric> defaultMetrics() {
|
||||
List<RegressionMetric> defaultMetrics = new ArrayList<>(1);
|
||||
List<RegressionMetric> defaultMetrics = new ArrayList<>(2);
|
||||
defaultMetrics.add(new MeanSquaredError());
|
||||
defaultMetrics.add(new RSquared());
|
||||
return defaultMetrics;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.common.xcontent.XContentParser;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
|
||||
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
||||
import org.elasticsearch.test.AbstractSerializingTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||
|
||||
@Override
|
||||
protected RSquared doParseInstance(XContentParser parser) throws IOException {
|
||||
return RSquared.fromXContent(parser);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected RSquared createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<RSquared> instanceReader() {
|
||||
return RSquared::new;
|
||||
}
|
||||
|
||||
public static RSquared createRandom() {
|
||||
return new RSquared();
|
||||
}
|
||||
|
||||
public void testEvaluate() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("residual_sum_of_squares", 10_111),
|
||||
createExtendedStatsAgg("extended_stats_actual", 155.23, 1000),
|
||||
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
|
||||
String expected = "{\"value\":0.9348643947690524}";
|
||||
assertThat(Strings.toString(result), equalTo(expected));
|
||||
}
|
||||
|
||||
public void testEvaluateWithZeroCount() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("residual_sum_of_squares", 0),
|
||||
createExtendedStatsAgg("extended_stats_actual", 0.0, 0),
|
||||
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingAggs() {
|
||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
|
||||
aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
||||
createSingleMetricAgg("residual_sum_of_squares", 0.2377)
|
||||
));
|
||||
|
||||
result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
|
||||
aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
||||
createExtendedStatsAgg("extended_stats_actual",100, 50)
|
||||
));
|
||||
|
||||
result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
}
|
||||
|
||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
||||
NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
|
||||
when(agg.getName()).thenReturn(name);
|
||||
when(agg.value()).thenReturn(value);
|
||||
return agg;
|
||||
}
|
||||
|
||||
private static ExtendedStats createExtendedStatsAgg(String name, double variance, long count) {
|
||||
ExtendedStats agg = mock(ExtendedStats.class);
|
||||
when(agg.getName()).thenReturn(name);
|
||||
when(agg.getVariance()).thenReturn(variance);
|
||||
when(agg.getCount()).thenReturn(count);
|
||||
return agg;
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
@ -32,8 +33,20 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
|||
}
|
||||
|
||||
public static Regression createRandom() {
|
||||
List<RegressionMetric> metrics = Collections.singletonList(MeanSquaredErrorTests.createRandom());
|
||||
return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), randomBoolean() ? null : metrics);
|
||||
List<RegressionMetric> metrics = new ArrayList<>();
|
||||
if (randomBoolean()) {
|
||||
metrics.add(MeanSquaredErrorTests.createRandom());
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
metrics.add(RSquaredTests.createRandom());
|
||||
}
|
||||
return new Regression(randomAlphaOfLength(10),
|
||||
randomAlphaOfLength(10),
|
||||
randomBoolean() ?
|
||||
null :
|
||||
metrics.isEmpty() ?
|
||||
null :
|
||||
metrics);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -567,6 +567,24 @@ setup:
|
|||
}
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
- is_false: regression.r_squared.value
|
||||
---
|
||||
"Test regression r_squared":
|
||||
- do:
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"regression": {
|
||||
"actual_field": "regression_field_act",
|
||||
"predicted_field": "regression_field_pred",
|
||||
"metrics": { "r_squared": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||
- is_false: regression.mean_squared_error
|
||||
---
|
||||
"Test regression with null metrics":
|
||||
- do:
|
||||
|
@ -583,3 +601,4 @@ setup:
|
|||
}
|
||||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||
|
|
Loading…
Reference in New Issue