* [ML] better handle empty results when evaluating regression * adding new failure test to ml_security black list * fixing equality check for regression results
This commit is contained in:
parent
686739d456
commit
ba7b677618
|
@ -69,7 +69,7 @@ public class MeanSquaredError implements RegressionMetric {
|
|||
@Override
|
||||
public EvaluationMetricResult evaluate(Aggregations aggs) {
|
||||
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
|
||||
return value == null ? null : new Result(value.value());
|
||||
return value == null ? new Result(0.0) : new Result(value.value());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -137,5 +137,18 @@ public class MeanSquaredError implements RegressionMetric {
|
|||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result other = (Result)o;
|
||||
return error == other.error;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ public class RSquared implements RegressionMetric {
|
|||
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
|
||||
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
|
||||
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
|
||||
null :
|
||||
new Result(0.0) :
|
||||
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
|
||||
}
|
||||
|
||||
|
@ -148,5 +148,18 @@ public class RSquared implements RegressionMetric {
|
|||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Result other = (Result)o;
|
||||
return value == other.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -121,6 +121,12 @@ public class Regression implements Evaluation {
|
|||
@Override
|
||||
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
|
||||
List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
|
||||
if (searchResponse.getHits().getTotalHits().value == 0) {
|
||||
listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
|
||||
actualField,
|
||||
predictedField));
|
||||
return;
|
||||
}
|
||||
for (RegressionMetric metric : metrics) {
|
||||
results.add(metric.evaluate(searchResponse.getAggregations()));
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
|
|||
for (int i = 0; i < recalls.length; i++) {
|
||||
double threshold = thresholds[i];
|
||||
Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
|
||||
Filter fnAgg =aggs.get(aggName(classInfo, threshold, Condition.FN));
|
||||
Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
|
||||
long tp = tpAgg.getDocCount();
|
||||
long fn = fnAgg.getDocCount();
|
||||
recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
|
||||
|
|
|
@ -17,9 +17,7 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
|
@ -64,7 +62,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
|
|||
|
||||
MeanSquaredError mse = new MeanSquaredError();
|
||||
EvaluationMetricResult result = mse.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
|
||||
}
|
||||
|
||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
||||
|
|
|
@ -18,9 +18,7 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
|
@ -70,17 +68,18 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
|||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingAggs() {
|
||||
EvaluationMetricResult zeroResult = new RSquared.Result(0.0);
|
||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
EvaluationMetricResult result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
assertThat(result, equalTo(zeroResult));
|
||||
|
||||
aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
||||
|
@ -88,7 +87,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
|||
));
|
||||
|
||||
result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
assertThat(result, equalTo(zeroResult));
|
||||
|
||||
aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
|
||||
|
@ -96,7 +95,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
|||
));
|
||||
|
||||
result = rSquared.evaluate(aggs);
|
||||
assertThat(result, is(nullValue()));
|
||||
assertThat(result, equalTo(zeroResult));
|
||||
}
|
||||
|
||||
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
|
||||
|
|
|
@ -86,6 +86,8 @@ integTest.runner {
|
|||
'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
|
||||
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
|
||||
'ml/evaluate_data_frame/Test regression given missing actual_field',
|
||||
'ml/evaluate_data_frame/Test regression given missing predicted_field',
|
||||
'ml/delete_job_force/Test cannot force delete a non-existent job',
|
||||
'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
|
||||
'ml/delete_model_snapshot/Test delete snapshot missing job_id',
|
||||
|
|
|
@ -602,3 +602,34 @@ setup:
|
|||
|
||||
- match: { regression.mean_squared_error.error: 28.67749840974834 }
|
||||
- match: { regression.r_squared.value: 0.8551031778603486 }
|
||||
---
|
||||
# Regression evaluation on index "utopia" with actual_field set to "missing":
# the request is expected to fail with the "No documents found containing both ..." error.
"Test regression given missing actual_field":
|
||||
- do:
|
||||
catch: /No documents found containing both \[missing, regression_field_pred\] fields/
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"regression": {
|
||||
"actual_field": "missing",
|
||||
"predicted_field": "regression_field_pred"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
# Regression evaluation on index "utopia" with predicted_field set to "missing":
# the request is expected to fail with the "No documents found containing both ..." error.
"Test regression given missing predicted_field":
|
||||
- do:
|
||||
catch: /No documents found containing both \[regression_field_act, missing\] fields/
|
||||
ml.evaluate_data_frame:
|
||||
body: >
|
||||
{
|
||||
"index": "utopia",
|
||||
"evaluation": {
|
||||
"regression": {
|
||||
"actual_field": "regression_field_act",
|
||||
"predicted_field": "missing"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue