[ML] Allow setting num_top_classes to a special value -1 (#63587) (#63602)

This commit is contained in:
Przemysław Witek 2020-10-13 13:57:50 +02:00 committed by GitHub
parent 652af26369
commit acbd48f834
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 25 additions and 18 deletions

View File

@ -47,7 +47,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
.setRandomizeSeed(randomBoolean() ? null : randomLong())
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(-1, 1000))
.setFeatureProcessors(randomBoolean() ? null :
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
OneHotEncodingTests.createRandom(),

View File

@ -125,7 +125,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
<10> The seed to be used by the random generator that picks which rows are used in training.
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
<12> The number of top classes to be reported in the results. Defaults to 2.
<12> The number of top classes (or -1 which denotes all classes) to be reported in the results. Defaults to 2.
<13> Custom feature processors that will create new features for analysis from the included document
fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.

View File

@ -134,8 +134,9 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees]
`num_top_classes`::::
(Optional, integer)
Defines the number of categories for which the predicted probabilities are
reported. It must be non-negative. If it is greater than the total number of
categories, the API reports all category probabilities. Defaults to 2.
reported. It must be non-negative or -1 (which denotes all categories). If it is
greater than the total number of categories, the API reports all category
probabilities. Defaults to 2.
`num_top_feature_importance_values`::::
(Optional, integer)

View File

@ -169,8 +169,9 @@ public class Classification implements DataFrameAnalysis {
@Nullable Double trainingPercent,
@Nullable Long randomizeSeed,
@Nullable List<PreProcessor> featureProcessors) {
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
throw ExceptionsHelper.badRequestException(
"[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
}
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());

View File

@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
null : randomFrom(Classification.ClassAssignmentObjective.values());
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
Long randomizeSeed = randomBoolean() ? null : randomLong();
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
@ -218,18 +218,18 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
}
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
}
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
}
public void testGetPredictionFieldName() {
@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(7));
// Special value: num_top_classes == -1
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(-1));
// Boundary condition: num_top_classes == 0
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
assertThat(classification.getNumTopClasses(), equalTo(0));

View File

@ -92,7 +92,7 @@ yamlRestTest {
'ml/data_frame_analytics_crud/Test put classification given max_trees is greater than 2k',
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than minus one',
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',

View File

@ -368,7 +368,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
String predictedClassField = dependentVariable + "_prediction";
indexData(sourceIndex, 300, 0, dependentVariable);
int numTopClasses = 2;
int numTopClasses = randomBoolean() ? 2 : -1; // Occasionally it's worth testing the special value -1.
int expectedNumTopClasses = 2;
DataFrameAnalyticsConfig config =
buildAnalytics(
jobId,
@ -392,7 +393,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues)));
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues);
assertTopClasses(resultsObject, expectedNumTopClasses, dependentVariable, dependentVariableValues);
// Let's just assert there's both training and non-training results
//

View File

@ -1868,10 +1868,10 @@ setup:
}
---
"Test put classification given num_top_classes is less than zero":
"Test put classification given num_top_classes is less than minus one":
- do:
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-less-than-one"
body: >
@ -1885,7 +1885,7 @@ setup:
"analysis": {
"classification": {
"dependent_variable": "foo",
"num_top_classes": -1
"num_top_classes": -2
}
}
}
@ -1894,7 +1894,7 @@ setup:
"Test put classification given num_top_classes is greater than 1k":
- do:
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
ml.put_data_frame_analytics:
id: "classification-training-percent-is-greater-than-hundred"
body: >