diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 30231feb9a7..e61e03cd0be 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -47,7 +47,7 @@ public class ClassificationTests extends AbstractXContentTestCase randomFrom(FrequencyEncodingTests.createRandom(), OneHotEncodingTests.createRandom(), diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 616b828ed00..db54d545866 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -125,7 +125,7 @@ include-tagged::{doc-tests-file}[{api}-classification] <9> The percentage of training-eligible rows to be used in training. Defaults to 100%. <10> The seed to be used by the random generator that picks which rows are used in training. <11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall. -<12> The number of top classes to be reported in the results. Defaults to 2. +<12> The number of top classes (or -1, which denotes all classes) to be reported in the results. Defaults to 2. <13> Custom feature processors that will create new features for analysis from the included document fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 88f3ee00264..6c0098d6224 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -134,8 +134,9 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees] `num_top_classes`:::: (Optional, integer) Defines the number of categories for which the predicted probabilities are -reported. It must be non-negative. If it is greater than the total number of -categories, the API reports all category probabilities. Defaults to 2. +reported. It must be non-negative or -1 (which denotes all categories). If it is +greater than the total number of categories, the API reports all category +probabilities. Defaults to 2. `num_top_feature_importance_values`:::: (Optional, integer) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 6f1777d55af..c88774f1289 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -169,8 +169,9 @@ public class Classification implements DataFrameAnalysis { @Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable List featureProcessors) { - if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { - throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); + if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) { + throw ExceptionsHelper.badRequestException( + "[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName()); } if 
(trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index ca963b2e139..3c979071ebc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null)); - assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); + assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null)); - assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); + assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1")); } public void testGetPredictionFieldName() { @@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase destDoc = getDestDoc(config, hit); Map resultsObject = getFieldValue(destDoc, "ml"); assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues))); - 
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues); + assertTopClasses(resultsObject, expectedNumTopClasses, dependentVariable, dependentVariableValues); // Let's just assert there's both training and non-training results // diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index b5816f80829..5a32dde99ed 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1868,10 +1868,10 @@ setup: } --- -"Test put classification given num_top_classes is less than zero": +"Test put classification given num_top_classes is less than minus one": - do: - catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ + catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/ ml.put_data_frame_analytics: id: "classification-training-percent-is-less-than-one" body: > @@ -1885,7 +1885,7 @@ setup: "analysis": { "classification": { "dependent_variable": "foo", - "num_top_classes": -1 + "num_top_classes": -2 } } } @@ -1894,7 +1894,7 @@ setup: "Test put classification given num_top_classes is greater than 1k": - do: - catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/ + catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/ ml.put_data_frame_analytics: id: "classification-training-percent-is-greater-than-hundred" body: >