This commit is contained in:
parent
652af26369
commit
acbd48f834
|
@ -47,7 +47,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
|
||||||
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
|
||||||
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
.setRandomizeSeed(randomBoolean() ? null : randomLong())
|
||||||
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
|
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
|
||||||
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
|
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(-1, 1000))
|
||||||
.setFeatureProcessors(randomBoolean() ? null :
|
.setFeatureProcessors(randomBoolean() ? null :
|
||||||
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
|
||||||
OneHotEncodingTests.createRandom(),
|
OneHotEncodingTests.createRandom(),
|
||||||
|
|
|
@ -125,7 +125,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
|
||||||
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
|
||||||
<10> The seed to be used by the random generator that picks which rows are used in training.
|
<10> The seed to be used by the random generator that picks which rows are used in training.
|
||||||
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
|
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
|
||||||
<12> The number of top classes to be reported in the results. Defaults to 2.
|
<12> The number of top classes (or -1 which denotes all classes) to be reported in the results. Defaults to 2.
|
||||||
<13> Custom feature processors that will create new features for analysis from the included document
|
<13> Custom feature processors that will create new features for analysis from the included document
|
||||||
fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
|
fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
|
||||||
|
|
||||||
|
|
|
@ -134,8 +134,9 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees]
|
||||||
`num_top_classes`::::
|
`num_top_classes`::::
|
||||||
(Optional, integer)
|
(Optional, integer)
|
||||||
Defines the number of categories for which the predicted probabilities are
|
Defines the number of categories for which the predicted probabilities are
|
||||||
reported. It must be non-negative. If it is greater than the total number of
|
reported. It must be non-negative or -1 (which denotes all categories). If it is
|
||||||
categories, the API reports all category probabilities. Defaults to 2.
|
greater than the total number of categories, the API reports all category
|
||||||
|
probabilities. Defaults to 2.
|
||||||
|
|
||||||
`num_top_feature_importance_values`::::
|
`num_top_feature_importance_values`::::
|
||||||
(Optional, integer)
|
(Optional, integer)
|
||||||
|
|
|
@ -169,8 +169,9 @@ public class Classification implements DataFrameAnalysis {
|
||||||
@Nullable Double trainingPercent,
|
@Nullable Double trainingPercent,
|
||||||
@Nullable Long randomizeSeed,
|
@Nullable Long randomizeSeed,
|
||||||
@Nullable List<PreProcessor> featureProcessors) {
|
@Nullable List<PreProcessor> featureProcessors) {
|
||||||
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
|
if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
|
||||||
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
|
throw ExceptionsHelper.badRequestException(
|
||||||
|
"[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
|
||||||
}
|
}
|
||||||
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
|
if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
|
||||||
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
|
throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
|
||||||
|
|
|
@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||||
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
|
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
|
||||||
null : randomFrom(Classification.ClassAssignmentObjective.values());
|
null : randomFrom(Classification.ClassAssignmentObjective.values());
|
||||||
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
|
||||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
||||||
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
||||||
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
||||||
|
@ -218,18 +218,18 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
|
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
||||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||||
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
||||||
|
|
||||||
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetPredictionFieldName() {
|
public void testGetPredictionFieldName() {
|
||||||
|
@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
||||||
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
||||||
assertThat(classification.getNumTopClasses(), equalTo(7));
|
assertThat(classification.getNumTopClasses(), equalTo(7));
|
||||||
|
|
||||||
|
// Special value: num_top_classes == -1
|
||||||
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
|
||||||
|
assertThat(classification.getNumTopClasses(), equalTo(-1));
|
||||||
|
|
||||||
// Boundary condition: num_top_classes == 0
|
// Boundary condition: num_top_classes == 0
|
||||||
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
||||||
assertThat(classification.getNumTopClasses(), equalTo(0));
|
assertThat(classification.getNumTopClasses(), equalTo(0));
|
||||||
|
|
|
@ -92,7 +92,7 @@ yamlRestTest {
|
||||||
'ml/data_frame_analytics_crud/Test put classification given max_trees is greater than 2k',
|
'ml/data_frame_analytics_crud/Test put classification given max_trees is greater than 2k',
|
||||||
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
|
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is negative',
|
||||||
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
|
'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
|
||||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
|
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than minus one',
|
||||||
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
|
'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
|
||||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
|
'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
|
||||||
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
|
'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
|
||||||
|
|
|
@ -368,7 +368,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
String predictedClassField = dependentVariable + "_prediction";
|
String predictedClassField = dependentVariable + "_prediction";
|
||||||
indexData(sourceIndex, 300, 0, dependentVariable);
|
indexData(sourceIndex, 300, 0, dependentVariable);
|
||||||
|
|
||||||
int numTopClasses = 2;
|
int numTopClasses = randomBoolean() ? 2 : -1; // Occasionally it's worth testing the special value -1.
|
||||||
|
int expectedNumTopClasses = 2;
|
||||||
DataFrameAnalyticsConfig config =
|
DataFrameAnalyticsConfig config =
|
||||||
buildAnalytics(
|
buildAnalytics(
|
||||||
jobId,
|
jobId,
|
||||||
|
@ -392,7 +393,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||||
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
||||||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues)));
|
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues)));
|
||||||
assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues);
|
assertTopClasses(resultsObject, expectedNumTopClasses, dependentVariable, dependentVariableValues);
|
||||||
|
|
||||||
// Let's just assert there's both training and non-training results
|
// Let's just assert there's both training and non-training results
|
||||||
//
|
//
|
||||||
|
|
|
@ -1868,10 +1868,10 @@ setup:
|
||||||
}
|
}
|
||||||
|
|
||||||
---
|
---
|
||||||
"Test put classification given num_top_classes is less than zero":
|
"Test put classification given num_top_classes is less than minus one":
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
|
catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
|
||||||
ml.put_data_frame_analytics:
|
ml.put_data_frame_analytics:
|
||||||
id: "classification-training-percent-is-less-than-one"
|
id: "classification-training-percent-is-less-than-one"
|
||||||
body: >
|
body: >
|
||||||
|
@ -1885,7 +1885,7 @@ setup:
|
||||||
"analysis": {
|
"analysis": {
|
||||||
"classification": {
|
"classification": {
|
||||||
"dependent_variable": "foo",
|
"dependent_variable": "foo",
|
||||||
"num_top_classes": -1
|
"num_top_classes": -2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1894,7 +1894,7 @@ setup:
|
||||||
"Test put classification given num_top_classes is greater than 1k":
|
"Test put classification given num_top_classes is greater than 1k":
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
catch: /\[num_top_classes\] must be an integer in \[0, 1000\]/
|
catch: /\[num_top_classes\] must be an integer in \[0, 1000\] or a special value -1/
|
||||||
ml.put_data_frame_analytics:
|
ml.put_data_frame_analytics:
|
||||||
id: "classification-training-percent-is-greater-than-hundred"
|
id: "classification-training-percent-is-greater-than-hundred"
|
||||||
body: >
|
body: >
|
||||||
|
|
Loading…
Reference in New Issue