This changes the valid range of `training_percent` for regression and classification from [1, 100] to (0, 100]. Backport of #61977
This commit is contained in:
parent
3396184ff3
commit
d37f197efd
|
@ -172,8 +172,8 @@ public class Classification implements DataFrameAnalysis {
|
|||
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
|
||||
}
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
|
|
|
@ -130,8 +130,8 @@ public class Regression implements DataFrameAnalysis {
|
|||
@Nullable LossFunction lossFunction,
|
||||
@Nullable Double lossFunctionParameter,
|
||||
@Nullable List<PreProcessor> featureProcessors) {
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
|
||||
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
|
||||
|
|
|
@ -93,7 +93,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
|
||||
null : randomFrom(Classification.ClassAssignmentObjective.values());
|
||||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||
numTopClasses, trainingPercent, randomizeSeed,
|
||||
|
@ -198,19 +198,25 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
public void testConstructor_GivenTrainingPercentIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
||||
|
|
|
@ -84,7 +84,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
private static Regression createRandom(BoostedTreeParams boostedTreeParams) {
|
||||
String dependentVariableName = randomAlphaOfLength(10);
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
|
||||
Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
|
||||
|
@ -196,11 +196,18 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
}
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
||||
public void testConstructor_GivenTrainingPercentIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
||||
|
@ -208,7 +215,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
|
||||
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenLossFunctionParameterIsZero() {
|
||||
|
|
|
@ -79,7 +79,7 @@ yamlRestTest {
|
|||
'ml/data_frame_analytics_crud/Test put regression given max_trees is greater than 2k',
|
||||
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
|
||||
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than zero',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
|
||||
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is zero',
|
||||
'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is negative',
|
||||
|
@ -94,7 +94,7 @@ yamlRestTest {
|
|||
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
|
||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
|
||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than one',
|
||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
|
||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
|
||||
'ml/estimate_model_memory/Test missing overall cardinality',
|
||||
'ml/estimate_model_memory/Test missing max bucket cardinality',
|
||||
|
|
|
@ -1522,10 +1522,10 @@ setup:
|
|||
}
|
||||
|
||||
---
|
||||
"Test put regression given training_percent is less than one":
|
||||
"Test put regression given training_percent is less than zero":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "regression-training-percent-is-less-than-one"
|
||||
body: >
|
||||
|
@ -1539,7 +1539,7 @@ setup:
|
|||
"analysis": {
|
||||
"regression": {
|
||||
"dependent_variable": "foo",
|
||||
"training_percent": 0.999
|
||||
"training_percent": -1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1548,7 +1548,7 @@ setup:
|
|||
"Test put regression given training_percent is greater than hundred":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "regression-training-percent-is-greater-than-hundred"
|
||||
body: >
|
||||
|
@ -1914,10 +1914,10 @@ setup:
|
|||
}
|
||||
|
||||
---
|
||||
"Test put classification given training_percent is less than one":
|
||||
"Test put classification given training_percent is less than zero":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-less-than-one"
|
||||
body: >
|
||||
|
@ -1931,7 +1931,7 @@ setup:
|
|||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"training_percent": 0.999
|
||||
"training_percent": -1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1940,7 +1940,7 @@ setup:
|
|||
"Test put classification given training_percent is greater than hundred":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
catch: /\[training_percent\] must be a positive double in \(0, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-greater-than-hundred"
|
||||
body: >
|
||||
|
|
Loading…
Reference in New Issue