This commit is contained in:
parent
652af26369
commit
acbd48f834
|
@ -47,7 +47,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
|||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
||||
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
|
||||
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
||||
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(-1, 1000))
|
||||
.setFeatureProcessors(randomBoolean() ? null :
|
||||
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||
OneHotEncodingTests.createRandom(),
|
||||
|
|
|
@ -125,7 +125,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
|||
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
|
||||
<12> The number of top classes to be reported in the results. Defaults to 2.
|
||||
<12> The number of top classes (or -1 which denotes all classes) to be reported in the results. Defaults to 2.
|
||||
<13> Custom feature processors that will create new features for analysis from the included document
|
||||
fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
|
||||
|
||||
|
|
|
@ -134,8 +134,9 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees]
|
|||
`num_top_classes`::::
|
||||
(Optional, integer)
|
||||
Defines the number of categories for which the predicted probabilities are
|
||||
reported. It must be non-negative. If it is greater than the total number of
|
||||
categories, the API reports all category probabilities. Defaults to 2.
|
||||
reported. It must be non-negative or -1 (which denotes all categories). If it is
|
||||
greater than the total number of categories, the API reports all category
|
||||
probabilities. Defaults to 2.
|
||||
|
||||
`num_top_feature_importance_values`::::
|
||||
(Optional, integer)
|
||||
|
|
|
@ -169,8 +169,9 @@ public class Classification implements DataFrameAnalysis {
|
|||
@Nullable Double trainingPercent,
|
||||
@Nullable Long randomizeSeed,
|
||||
@Nullable List<PreProcessor> featureProcessors) {
|
||||
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
|
||||
if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
|
||||
throw ExceptionsHelper.badRequestException(
|
||||
"[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
|
||||
}
|
||||
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
|
|
|
@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
|
||||
null : randomFrom(Classification.ClassAssignmentObjective.values());
|
||||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
||||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||
|
@ -218,18 +218,18 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
||||
public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
||||
}
|
||||
|
||||
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
||||
}
|
||||
|
||||
public void testGetPredictionFieldName() {
|
||||
|
@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||
|
||||
// Special value: num_top_classes == -1
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(-1));
|
||||
|
||||
// Boundary condition: num_top_classes == 0
|
||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
||||
assertThat(classification.getNumTopClasses(), equalTo(0));
|
||||
|
|
|
@ -92,7 +92,7 @@ yamlRestTest {
|
|||
'ml/data_frame_analytics_crud/Test put classification given max_trees is greater than 2k',
|
||||
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
|
||||
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
|
||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than minus one',
|
||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
|
||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
|
||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
|
||||
|
|
|
@ -368,7 +368,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
String predictedClassField = dependentVariable + "_prediction";
|
||||
indexData(sourceIndex, 300, 0, dependentVariable);
|
||||
|
||||
int numTopClasses = 2;
|
||||
int numTopClasses = randomBoolean() ? 2 : -1; // Occasionally it's worth testing the special value -1.
|
||||
int expectedNumTopClasses = 2;
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(
|
||||
jobId,
|
||||
|
@ -392,7 +393,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
||||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues)));
|
||||
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues);
|
||||
assertTopClasses(resultsObject, expectedNumTopClasses, dependentVariable, dependentVariableValues);
|
||||
|
||||
// Let's just assert there's both training and non-training results
|
||||
//
|
||||
|
|
|
@ -1868,10 +1868,10 @@ setup:
|
|||
}
|
||||
|
||||
---
|
||||
"Test put classification given num_top_classes is less than zero":
|
||||
"Test put classification given num_top_classes is less than minus one":
|
||||
|
||||
- do:
|
||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
|
||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-less-than-one"
|
||||
body: >
|
||||
|
@ -1885,7 +1885,7 @@ setup:
|
|||
"analysis": {
|
||||
"classification": {
|
||||
"dependent_variable": "foo",
|
||||
"num_top_classes": -1
|
||||
"num_top_classes": -2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1894,7 +1894,7 @@ setup:
|
|||
"Test put classification given num_top_classes is greater than 1k":
|
||||
|
||||
- do:
|
||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
|
||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "classification-training-percent-is-greater-than-hundred"
|
||||
body: >
|
||||
|
|
Loading…
Reference in New Issue