diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java index abb6cfb06b1..4233110abfe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; @@ -137,6 +138,15 @@ public class InternalInferModelAction extends ActionType { - public static RegressionConfigUpdate randomRegressionConfig() { + public static RegressionConfigUpdate randomRegressionConfigUpdate() { return new RegressionConfigUpdate(randomBoolean() ? null : randomAlphaOfLength(10), randomBoolean() ? null : randomIntBetween(0, 10)); } @@ -40,9 +41,28 @@ public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCas assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key].")); } + public void testApply() { + RegressionConfig originalConfig = randomRegressionConfig(); + + assertThat(originalConfig, equalTo(RegressionConfigUpdate.EMPTY_PARAMS.apply(originalConfig))); + + assertThat(new RegressionConfig.Builder(originalConfig).setNumTopFeatureImportanceValues(5).build(), + equalTo(new RegressionConfigUpdate.Builder().setNumTopFeatureImportanceValues(5).build().apply(originalConfig))); + assertThat(new RegressionConfig.Builder() + .setNumTopFeatureImportanceValues(1) + .setResultsField("foo") + .build(), + equalTo(new RegressionConfigUpdate.Builder() + .setNumTopFeatureImportanceValues(1) + .setResultsField("foo") + .build() + .apply(originalConfig) + )); + } + @Override protected RegressionConfigUpdate createTestInstance() { - return randomRegressionConfig(); + return randomRegressionConfigUpdate(); } @Override