[7.x][ML] Allow training_percent to be any positive double up to one hundred (#61977) (#61990)

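This changes the valid range of `training_percent` for regression and
classification from [1, 100] to (0, 100]. Fractional percentages below 1
(for example 0.5) are now accepted; zero and negative values, as well as
values greater than 100, are still rejected.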

Backport of #61977
This commit is contained in:
Dimitris Athanasiou 2020-09-04 17:34:14 +03:00 committed by GitHub
parent 3396184ff3
commit d37f197efd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 38 additions and 25 deletions

View File

@ -172,8 +172,8 @@ public class Classification implements DataFrameAnalysis {
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
}
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);

View File

@ -130,8 +130,8 @@ public class Regression implements DataFrameAnalysis {
@Nullable LossFunction lossFunction,
@Nullable Double lossFunctionParameter,
@Nullable List<PreProcessor> featureProcessors) {
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);

View File

@ -93,7 +93,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
null : randomFrom(Classification.ClassAssignmentObjective.values());
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
numTopClasses, trainingPercent, randomizeSeed,
@ -198,19 +198,25 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
}
}
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
public void testConstructor_GivenTrainingPercentIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenNumTopClassesIsLessThanZero() {

View File

@ -84,7 +84,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
private static Regression createRandom(BoostedTreeParams boostedTreeParams) {
String dependentVariableName = randomAlphaOfLength(10);
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
Long randomizeSeed = randomBoolean() ? null : randomLong();
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
@ -196,11 +196,18 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
}
}
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
public void testConstructor_GivenTrainingPercentIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), Regression.LossFunction.MSE, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), Regression.LossFunction.MSE, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
@ -208,7 +215,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenLossFunctionParameterIsZero() {

View File

@ -79,7 +79,7 @@ yamlRestTest {
'ml/data_frame_analytics_crud/Test put regression given max_trees is greater than 2k',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than zero',
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is zero',
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is negative',
@ -94,7 +94,7 @@ yamlRestTest {
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than one',
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
'ml/estimate_model_memory/Test missing overall cardinality',
'ml/estimate_model_memory/Test missing max bucket cardinality',

View File

@ -1522,10 +1522,10 @@ setup:
}
---
"Test put regression given training_percent is less than one":
"Test put regression given training_percent is less than zero":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
ml.put_data_frame_analytics:
id: "regression-training-percent-is-less-than-one"
body: >
@ -1539,7 +1539,7 @@ setup:
"analysis": {
"regression": {
"dependent_variable": "foo",
"training_percent": 0.999
"training_percent": -1.0
}
}
}
@ -1548,7 +1548,7 @@ setup:
"Test put regression given training_percent is greater than hundred":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
ml.put_data_frame_analytics:
id: "regression-training-percent-is-greater-than-hundred"
body: >
@ -1914,10 +1914,10 @@ setup:
}
---
"Test put classification given training_percent is less than one":
"Test put classification given training_percent is less than zero":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-less-than-one"
body: >
@ -1931,7 +1931,7 @@ setup:
"analysis": {
"classification": {
"dependent_variable": "foo",
"training_percent": 0.999
"training_percent": -1.0
}
}
}
@ -1940,7 +1940,7 @@ setup:
"Test put classification given training_percent is greater than hundred":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-greater-than-hundred"
body: >