From c62fe8c344582a91d2f544e1960471f79015f446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 11 Oct 2019 14:57:08 +0200 Subject: [PATCH] Require that the dependent variable column has at most 2 distinct values in classfication analysis. (#47858) (#47906) --- .../ml/dataframe/analyses/Classification.java | 6 + .../dataframe/analyses/DataFrameAnalysis.java | 5 + .../dataframe/analyses/OutlierDetection.java | 5 + .../ml/dataframe/analyses/Regression.java | 5 + .../analyses/ClassificationTests.java | 7 ++ .../analyses/OutlierDetectionTests.java | 6 + .../dataframe/analyses/RegressionTests.java | 6 + .../ml/integration/ClassificationIT.java | 110 +++++++++--------- .../DataFrameDataExtractorFactory.java | 70 ++++++++++- .../test/ml/start_data_frame_analytics.yml | 51 +++++++- 10 files changed, 208 insertions(+), 63 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 96c03b7692f..199158cceaa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -152,6 +152,12 @@ public class Classification implements DataFrameAnalysis { return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical())); } + @Override + public Map getFieldCardinalityLimits() { + // This restriction is due to the fact that currently the C++ backend only supports binomial classification. + return Collections.singletonMap(dependentVariable, 2L); + } + @Override public boolean supportsMissingValues() { return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index 4f89388f912..d23097f5816 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -28,6 +28,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { */ List getRequiredFields(); + /** + * @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields + */ + Map getFieldCardinalityLimits(); + /** * @return {@code true} if this analysis supports data frame rows with missing values */ diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 47325ffdea8..055c97a511a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -218,6 +218,11 @@ public class OutlierDetection implements DataFrameAnalysis { return Collections.emptyList(); } + @Override + public Map getFieldCardinalityLimits() { + return Collections.emptyMap(); + } + @Override public boolean supportsMissingValues() { return false; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index e804c7d1761..34a93713385 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -139,6 +139,11 @@ public class Regression implements DataFrameAnalysis { return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical())); } + @Override + public Map getFieldCardinalityLimits() { + return Collections.emptyMap(); + } + @Override public boolean supportsMissingValues() { return true; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index e67f2970946..2cc1fae8eee 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -13,6 +13,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; public class ClassificationTests extends AbstractSerializingTestCase { @@ -65,4 +68,8 @@ public class ClassificationTests extends AbstractSerializingTestCase { @@ -82,6 +84,10 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase { @@ -66,6 +68,10 @@ public class RegressionTests extends AbstractSerializingTestCase { assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } + public void testFieldCardinalityLimitsIsNonNull() { + assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); + } + public void testGetStateDocId() { Regression regression = createRandom(); assertThat(regression.persistsState(), is(true)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 8a8040f586f..0244c2584e0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.integration; import com.google.common.collect.Ordering; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetResponse; @@ -37,10 +38,10 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo; public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { - private static final String NUMERICAL_FEATURE_FIELD = "feature"; - private static final String DEPENDENT_VARIABLE_FIELD = "variable"; - private static final List NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0)); - private static final List DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow")); + private static final String NUMERICAL_FIELD = "numerical-field"; + private static final String KEYWORD_FIELD = "keyword-field"; + private static final List NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + private static final List KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat")); private String jobId; private String sourceIndex; @@ -53,36 +54,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("classification_single_numeric_feature_and_mixed_data_set"); + indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); - { // Index 350 rows, 300 of them being training rows. - client().admin().indices().prepareCreate(sourceIndex) - .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword") - .get(); - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - for (int i = 0; i < 300; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - String value = DEPENDENT_VARIABLE_VALUES.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); - bulkRequestBuilder.add(indexRequest); - } - for (int i = 300; i < 350; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - - IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FEATURE_FIELD, field); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - fail("Failed to index data: " + bulkResponse.buildFailureMessage()); - } - } - - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); registerAnalytics(config); putAnalytics(config); @@ -97,10 +71,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { Map destDoc = getDestDoc(config, hit); Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); - assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES))); + assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); + assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey("is_training"), is(true)); - assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertThat(resultsObject.containsKey("top_classes"), is(false)); } @@ -117,9 +91,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { initialize("classification_only_training_data_and_training_percent_is_100"); - indexTrainingData(sourceIndex, 300); + indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); registerAnalytics(config); putAnalytics(config); @@ -133,8 +107,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); - assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES))); + assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); + assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(true)); assertThat(resultsObject.containsKey("top_classes"), is(false)); @@ -153,7 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { initialize("classification_only_training_data_and_training_percent_is_50"); - indexTrainingData(sourceIndex, 300); + indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); DataFrameAnalyticsConfig config = buildAnalytics( @@ -161,7 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0)); + new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0)); registerAnalytics(config); putAnalytics(config); @@ -176,8 +150,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); - assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES))); + assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); + assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey("is_training"), is(true)); // Let's just assert there's both training and non-training results @@ -205,7 +179,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712") public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception { initialize("classification_top_classes_requested"); - indexTrainingData(sourceIndex, 300); + indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES); int numTopClasses = 2; DataFrameAnalyticsConfig config = @@ -214,7 +188,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { sourceIndex, destIndex, null, - new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null)); + new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null)); registerAnalytics(config); putAnalytics(config); @@ -229,8 +203,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { Map destDoc = getDestDoc(config, hit); Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); - assertThat(resultsObject.containsKey("variable_prediction"), is(true)); - assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES))); + assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); + assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertTopClasses(resultsObject, numTopClasses); } @@ -245,25 +219,47 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { "Finished analysis"); } + public void testDependentVariableCardinalityTooHighError() { + initialize("cardinality_too_high"); + indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox")); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); + registerAnalytics(config); + putAnalytics(config); + + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(jobId)); + assertThat(e.status().getStatus(), equalTo(400)); + assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]")); + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; this.destIndex = sourceIndex + "_results"; } - private static void indexTrainingData(String sourceIndex, int numRows) { + private static void indexData(String sourceIndex, + int numTrainingRows, int numNonTrainingRows, + List numericalFieldValues, List keywordFieldValues) { client().admin().indices().prepareCreate(sourceIndex) - .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword") + .addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword") .get(); BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - for (int i = 0; i < numRows; i++) { - Double field = NUMERICAL_FEATURE_VALUES.get(i % 3); - String value = DEPENDENT_VARIABLE_VALUES.get(i % 3); + for (int i = 0; i < numTrainingRows; i++) { + Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size()); + String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size()); IndexRequest indexRequest = new IndexRequest(sourceIndex) - .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value); + .source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue); + bulkRequestBuilder.add(indexRequest); + } + for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) { + Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size()); + + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .source(NUMERICAL_FIELD, numericalValue); bulkRequestBuilder.add(indexRequest); } BulkResponse bulkResponse = bulkRequestBuilder.get(); @@ -302,10 +298,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { classNames.add((String) topClass.get("class_name")); classProbabilities.add((Double) topClass.get("class_probability")); } - // Assert that all the class names come from the set of dependent variable values. - classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES)))); + // Assert that all the predicted class names come from the set of keyword field values. + classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES)))); // Assert that the first class listed in top classes is the same as the predicted class. - assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction"))); + assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction"))); // Assert that all the class probabilities lie within [0, 1] interval. classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)))); // Assert that the top classes are listed in the order of decreasing probabilities. diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index a93efe67319..0546f94b533 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -13,6 +13,9 @@ import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Client; import org.elasticsearch.common.collect.ImmutableOpenMap; @@ -22,6 +25,10 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.metrics.Cardinality; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -34,6 +41,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; public class DataFrameDataExtractorFactory { @@ -172,13 +180,65 @@ public class DataFrameDataExtractorFactory { boolean isTaskRestarting, ActionListener listener) { AtomicInteger docValueFieldsLimitHolder = new AtomicInteger(); + AtomicReference extractedFieldsHolder = new AtomicReference<>(); - // Step 3. Extract fields (if possible) and notify listener + // Step 4. Check fields cardinality vs limits and notify listener + ActionListener checkCardinalityHandler = ActionListener.wrap( + searchResponse -> { + if (searchResponse != null) { + Aggregations aggs = searchResponse.getAggregations(); + if (aggs == null) { + listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities")); + return; + } + for (Map.Entry entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) { + String fieldName = entry.getKey(); + Long limit = entry.getValue(); + Cardinality cardinality = aggs.get(fieldName); + if (cardinality == null) { + listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities")); + return; + } + if (cardinality.getValue() > limit) { + listener.onFailure( + ExceptionsHelper.badRequestException( + "Field [{}] must have at most [{}] distinct values but there were at least [{}]", + fieldName, limit, cardinality.getValue())); + return; + } + } + } + listener.onResponse(extractedFieldsHolder.get()); + }, + listener::onFailure + ); + + // Step 3. Extract fields (if possible) ActionListener fieldCapabilitiesHandler = ActionListener.wrap( - fieldCapabilitiesResponse -> listener.onResponse( - new ExtractedFieldsDetector( - index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse) - .detect()), + fieldCapabilitiesResponse -> { + extractedFieldsHolder.set( + new ExtractedFieldsDetector( + index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse) + .detect()); + + Map fieldCardinalityLimits = config.getAnalysis().getFieldCardinalityLimits(); + if (fieldCardinalityLimits.isEmpty()) { + checkCardinalityHandler.onResponse(null); + } else { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0); + for (Map.Entry entry : fieldCardinalityLimits.entrySet()) { + String fieldName = entry.getKey(); + Long limit = entry.getValue(); + searchSourceBuilder.aggregation( + AggregationBuilders.cardinality(fieldName) + .field(fieldName) + .precisionThreshold(limit + 1)); + } + SearchRequest searchRequest = new SearchRequest(config.getSource().getIndex()).source(searchSourceBuilder); + ClientHelper.executeWithHeadersAsync( + config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE, searchRequest, checkCardinalityHandler); + } + }, listener::onFailure ); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml index 3e9c73a3fa8..9f08ed89b1f 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml @@ -69,7 +69,7 @@ body: mappings: properties: - long_field: { "type": "long" } + long_field: { type: "long" } - do: ml.put_data_frame_analytics: @@ -140,3 +140,52 @@ catch: /dest index \[non-empty-dest\] must be empty/ ml.start_data_frame_analytics: id: "start_given_empty_dest_index" + +--- +"Test start classification analysis when the dependent variable cardinality is too high": + - do: + indices.create: + index: index-with-dep-var-with-too-high-card + body: + mappings: + properties: + numeric_field: { type: "long" } + keyword_field: { type: "keyword" } + + - do: + index: + index: index-with-dep-var-with-too-high-card + body: { numeric_field: 1.0, keyword_field: "class_a" } + + - do: + index: + index: index-with-dep-var-with-too-high-card + body: { numeric_field: 2.0, keyword_field: "class_b" } + + - do: + index: + index: index-with-dep-var-with-too-high-card + body: { numeric_field: 3.0, keyword_field: "class_c" } + + - do: + indices.refresh: + index: index-with-dep-var-with-too-high-card + + - do: + ml.put_data_frame_analytics: + id: "too-high-card" + body: > + { + "source": { + "index": "index-with-dep-var-with-too-high-card" + }, + "dest": { + "index": "index-with-dep-var-with-too-high-card-dest" + }, + "analysis": { "classification": { "dependent_variable": "keyword_field" } } + } + + - do: + catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/ + ml.start_data_frame_analytics: + id: "too-high-card"