diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 31dd15783d3..a7cdf239083 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -172,8 +172,8 @@ public class Classification implements DataFrameAnalysis { if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); } - if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { - throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); + if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName()); } this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index a8d53c9433f..2180a14eaf0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -130,8 +130,8 @@ public class Regression implements DataFrameAnalysis { @Nullable LossFunction lossFunction, @Nullable Double lossFunctionParameter, @Nullable List featureProcessors) { - if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { - throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); + if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName()); } this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 0df22b225ca..7e8c247bd03 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -93,7 +93,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null)); - assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); + assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]")); + } + + public void testConstructor_GivenTrainingPercentIsLessThanZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null)); + + assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null)); - assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); + assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]")); } public void testConstructor_GivenNumTopClassesIsLessThanZero() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index 6fe9a090009..b44c2351729 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -84,7 +84,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), Regression.LossFunction.MSE, null, null)); - assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); + assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]")); + } + + public void testConstructor_GivenTrainingPercentIsLessThanZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), Regression.LossFunction.MSE, null, null)); + + assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { @@ -208,7 +215,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null)); - assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); + assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]")); } public void testConstructor_GivenLossFunctionParameterIsZero() { diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 93f943d7ddf..e0d9b4afd06 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -79,7 +79,7 @@ yamlRestTest { 'ml/data_frame_analytics_crud/Test put regression given max_trees is greater than 2k', 'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative', 'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one', - 'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one', + 'ml/data_frame_analytics_crud/Test put regression given training_percent is less than zero', 'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred', 'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is zero', 'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is negative', @@ -94,7 +94,7 @@ yamlRestTest { 'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one', 'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero', 'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k', - 'ml/data_frame_analytics_crud/Test put classification given training_percent is less than one', + 'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero', 'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred', 'ml/estimate_model_memory/Test missing overall cardinality', 'ml/estimate_model_memory/Test missing max bucket cardinality', diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index cd9a381fd89..b5816f80829 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1522,10 +1522,10 @@ setup: } --- -"Test put regression given training_percent is less than one": +"Test put regression given training_percent is less than zero": - do: - catch: /\[training_percent\] must be a double in \[1, 100\]/ + catch: /\[training_percent\] must be a positive double in \(0, 100\]/ ml.put_data_frame_analytics: id: "regression-training-percent-is-less-than-one" body: > @@ -1539,7 +1539,7 @@ setup: "analysis": { "regression": { "dependent_variable": "foo", - "training_percent": 0.999 + "training_percent": -1.0 } } } @@ -1548,7 +1548,7 @@ setup: "Test put regression given training_percent is greater than hundred": - do: - catch: /\[training_percent\] must be a double in \[1, 100\]/ + catch: /\[training_percent\] must be a positive double in \(0, 100\]/ ml.put_data_frame_analytics: id: "regression-training-percent-is-greater-than-hundred" body: > @@ -1914,10 +1914,10 @@ setup: } --- -"Test put classification given training_percent is less than one": +"Test put classification given training_percent is less than zero": - do: - catch: /\[training_percent\] must be a double in \[1, 100\]/ + catch: /\[training_percent\] must be a positive double in \(0, 100\]/ ml.put_data_frame_analytics: id: "classification-training-percent-is-less-than-one" body: > @@ -1931,7 +1931,7 @@ setup: "analysis": { "classification": { "dependent_variable": "foo", - "training_percent": 0.999 + "training_percent": -1.0 } } } @@ -1940,7 +1940,7 @@ setup: "Test put classification given training_percent is greater than hundred": - do: - catch: /\[training_percent\] must be a double in \[1, 100\]/ + catch: /\[training_percent\] must be a positive double in \(0, 100\]/ ml.put_data_frame_analytics: id: "classification-training-percent-is-greater-than-hundred" body: >