From 4be366396840b9d819801a234c36caf801419bf1 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 17 Apr 2020 14:45:02 -0400 Subject: [PATCH] [7.x] [ML] fix bugs with prediction field value settings (#55333) (#55394) * [ML] fix bugs with prediction field value settings (#55333) This fixes two unreleased bugs: 1. Prediction value type of `number` might show unexpected classes Analytics created models may have class labels like `1, 5, 10` (or some collection of discrete, whole numbers). These labels are passed to the inference model config in the `classification_labels` field. When the predicted value format is `numeric` it should attempt to see if the classification labels are provided and are numeric. If so, use those. If not, use the underlying value. 2. When supplying an update overwrite, inference was losing the default prediction field value. This is because it was not copied over in the copy ctor in the ClassificationConfig.Builder class. closes #55332 --- .../ml/action/InternalInferModelAction.java | 10 +++++++ .../trainedmodel/ClassificationConfig.java | 1 + .../ClassificationConfigUpdate.java | 8 +++-- .../trainedmodel/PredictionFieldType.java | 14 +++++++++ .../InternalInferModelActionRequestTests.java | 4 +-- .../ClassificationConfigTests.java | 3 ++ .../ClassificationConfigUpdateTests.java | 29 +++++++++++++++++-- .../PredictionFieldTypeTests.java | 3 ++ .../RegressionConfigUpdateTests.java | 24 +++++++++++++-- 9 files changed, 88 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java index abb6cfb06b1..4233110abfe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; @@ -137,6 +138,15 @@ public class InternalInferModelAction extends ActionType { - public static RegressionConfigUpdate randomRegressionConfig() { + public static RegressionConfigUpdate randomRegressionConfigUpdate() { return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : randomIntBetween(0, 10)); } @@ -40,9 +41,28 @@ public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCas assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); } + public void testApply() { + RegressionConfig originalConfig = randomRegressionConfig(); + + assertThat(originalConfig, equalTo(RegressionConfigUpdate.EMPTY_PARAMS.apply(originalConfig))); + + assertThat(new RegressionConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(5).build(), + equalTo(new RegressionConfigUpdate.Builder().setNumTopFeatureImportanceValues(5).build().apply(originalConfig))); + assertThat(new RegressionConfig.Builder() + .setNumTopFeatureImportanceValues(1) + .setResultsField("foo") + .build(), + equalTo(new RegressionConfigUpdate.Builder() + .setNumTopFeatureImportanceValues(1) + .setResultsField("foo") + .build() + .apply(originalConfig) + )); + } + @Override protected RegressionConfigUpdate createTestInstance() { - return randomRegressionConfig(); + return randomRegressionConfigUpdate(); } @Override