[ML] Allow setting num_top_classes to a special value -1 (#63587) (#63602)

This commit is contained in:
Przemysław Witek 2020-10-13 13:57:50 +02:00 committed by GitHub
parent 652af26369
commit acbd48f834
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 25 additions and 18 deletions

View File

@ -47,7 +47,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setRandomizeSeed(randomBoolean() ? null : randomLong()) .setRandomizeSeed(randomBoolean() ? null : randomLong())
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values())) .setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10)) .setNumTopClasses(randomBoolean() ? null : randomIntBetween(-1, 1000))
.setFeatureProcessors(randomBoolean() ? null : .setFeatureProcessors(randomBoolean() ? null :
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(), Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
OneHotEncodingTests.createRandom(), OneHotEncodingTests.createRandom(),

View File

@ -125,7 +125,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. <9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<10> The seed to be used by the random generator that picks which rows are used in training. <10> The seed to be used by the random generator that picks which rows are used in training.
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall. <11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
<12> The number of top classes to be reported in the results. Defaults to 2. <12> The number of top classes (or -1 which denotes all classes) to be reported in the results. Defaults to 2.
<13> Custom feature processors that will create new features for analysis from the included document <13> Custom feature processors that will create new features for analysis from the included document
fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features. fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.

View File

@ -134,8 +134,9 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees]
`num_top_classes`:::: `num_top_classes`::::
(Optional, integer) (Optional, integer)
Defines the number of categories for which the predicted probabilities are Defines the number of categories for which the predicted probabilities are
reported. It must be non-negative. If it is greater than the total number of reported. It must be non-negative or -1 (which denotes all categories). If it is
categories, the API reports all category probabilities. Defaults to 2. greater than the total number of categories, the API reports all category
probabilities. Defaults to 2.
`num_top_feature_importance_values`:::: `num_top_feature_importance_values`::::
(Optional, integer) (Optional, integer)

View File

@ -169,8 +169,9 @@ public class Classification implements DataFrameAnalysis {
@Nullable Double trainingPercent, @Nullable Double trainingPercent,
@Nullable Long randomizeSeed, @Nullable Long randomizeSeed,
@Nullable List<PreProcessor> featureProcessors) { @Nullable List<PreProcessor> featureProcessors) {
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); throw ExceptionsHelper.badRequestException(
"[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
} }
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) { if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName()); throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());

View File

@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ? Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
null : randomFrom(Classification.ClassAssignmentObjective.values()); null : randomFrom(Classification.ClassAssignmentObjective.values());
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000); Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false); Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
Long randomizeSeed = randomBoolean() ? null : randomLong(); Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective, return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
@ -218,18 +218,18 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]")); assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
} }
public void testConstructor_GivenNumTopClassesIsLessThanZero() { public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null)); () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
} }
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null)); () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
} }
public void testGetPredictionFieldName() { public void testGetPredictionFieldName() {
@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null); Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(7)); assertThat(classification.getNumTopClasses(), equalTo(7));
// Special value: num_top_classes == -1
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(-1));
// Boundary condition: num_top_classes == 0 // Boundary condition: num_top_classes == 0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null); classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(0)); assertThat(classification.getNumTopClasses(), equalTo(0));

View File

@ -92,7 +92,7 @@ yamlRestTest {
'ml/data_frame_analytics_crud/Test put classification given max_trees is greater than 2k', 'ml/data_frame_analytics_crud/Test put classification given max_trees is greater than 2k',
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative', 'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one', 'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero', 'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than minus one',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k', 'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero', 'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred', 'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',

View File

@ -368,7 +368,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
String predictedClassField = dependentVariable + "_prediction"; String predictedClassField = dependentVariable + "_prediction";
indexData(sourceIndex, 300, 0, dependentVariable); indexData(sourceIndex, 300, 0, dependentVariable);
int numTopClasses = 2; int numTopClasses = randomBoolean() ? 2 : -1; // Occasionally it's worth testing the special value -1.
int expectedNumTopClasses = 2;
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig config =
buildAnalytics( buildAnalytics(
jobId, jobId,
@ -392,7 +393,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
Map<String, Object> destDoc = getDestDoc(config, hit); Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml"); Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues))); assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues)));
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues); assertTopClasses(resultsObject, expectedNumTopClasses, dependentVariable, dependentVariableValues);
// Let's just assert there's both training and non-training results // Let's just assert there's both training and non-training results
// //

View File

@ -1868,10 +1868,10 @@ setup:
} }
--- ---
"Test put classification given num_top_classes is less than zero": "Test put classification given num_top_classes is less than minus one":
- do: - do:
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
ml.put_data_frame_analytics: ml.put_data_frame_analytics:
id: "classification-training-percent-is-less-than-one" id: "classification-training-percent-is-less-than-one"
body: > body: >
@ -1885,7 +1885,7 @@ setup:
"analysis": { "analysis": {
"classification": { "classification": {
"dependent_variable": "foo", "dependent_variable": "foo",
"num_top_classes": -1 "num_top_classes": -2
} }
} }
} }
@ -1894,7 +1894,7 @@ setup:
"Test put classification given num_top_classes is greater than 1k": "Test put classification given num_top_classes is greater than 1k":
- do: - do:
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
ml.put_data_frame_analytics: ml.put_data_frame_analytics:
id: "classification-training-percent-is-greater-than-hundred" id: "classification-training-percent-is-greater-than-hundred"
body: > body: >