Require that the dependent variable column has at most 2 distinct values in classification analysis. (#47858) (#47906)
parent a0d0866f59
commit c62fe8c344
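In short, this change adds a getFieldCardinalityLimits() hook to DataFrameAnalysis: classification limits its dependent variable to 2 distinct values (the C++ backend currently only supports binomial classification), while regression and outlier detection report no limits. Before an analytics job starts, the observed cardinality of each restricted field is measured with a cardinality aggregation and the job is rejected with a 400 error if the limit is exceeded. Below is a minimal, self-contained sketch of that contract; the class and method names (FieldCardinalityLimitSketch, SimpleAnalysis, checkCardinality) are illustrative only and are not the actual Elasticsearch classes.

import java.util.Collections;
import java.util.Map;

// Minimal sketch only: a stripped-down model of the contract added in this commit.
public class FieldCardinalityLimitSketch {

    // Mirrors DataFrameAnalysis#getFieldCardinalityLimits(): each analysis reports
    // per-field limits; classification limits its dependent variable to 2 distinct values.
    interface SimpleAnalysis {
        Map<String, Long> getFieldCardinalityLimits();
    }

    static SimpleAnalysis classification(String dependentVariable) {
        return () -> Collections.singletonMap(dependentVariable, 2L);
    }

    // Mirrors the start-time validation: compare the observed distinct-value count
    // (obtained via a cardinality aggregation in the real code) against each limit.
    static void checkCardinality(SimpleAnalysis analysis, Map<String, Long> observedCardinalities) {
        for (Map.Entry<String, Long> entry : analysis.getFieldCardinalityLimits().entrySet()) {
            String field = entry.getKey();
            long limit = entry.getValue();
            long observed = observedCardinalities.getOrDefault(field, 0L);
            if (observed > limit) {
                throw new IllegalArgumentException("Field [" + field + "] must have at most [" + limit
                    + "] distinct values but there were at least [" + observed + "]");
            }
        }
    }

    public static void main(String[] args) {
        SimpleAnalysis analysis = classification("keyword-field");
        checkCardinality(analysis, Collections.singletonMap("keyword-field", 2L)); // passes
        checkCardinality(analysis, Collections.singletonMap("keyword-field", 3L)); // throws
    }
}

With a limit of 2 and an observed cardinality of 3 the check throws, mirroring the error message asserted in ClassificationIT and in the YAML REST test in the hunks below.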
@@ -152,6 +152,12 @@ public class Classification implements DataFrameAnalysis {
         return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        // This restriction is due to the fact that currently the C++ backend only supports binomial classification.
+        return Collections.singletonMap(dependentVariable, 2L);
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return true;
@@ -28,6 +28,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
      */
     List<RequiredField> getRequiredFields();
 
+    /**
+     * @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
+     */
+    Map<String, Long> getFieldCardinalityLimits();
+
     /**
      * @return {@code true} if this analysis supports data frame rows with missing values
      */
@@ -218,6 +218,11 @@ public class OutlierDetection implements DataFrameAnalysis {
         return Collections.emptyList();
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        return Collections.emptyMap();
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return false;
@@ -139,6 +139,11 @@ public class Regression implements DataFrameAnalysis {
         return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        return Collections.emptyMap();
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return true;
@@ -13,6 +13,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
 import java.io.IOException;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
 
@@ -65,4 +68,8 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
+
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
 }
@@ -15,6 +15,8 @@ import java.util.Map;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
 
@@ -82,6 +84,10 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
         assertThat(params.get(OutlierDetection.STANDARDIZATION_ENABLED.getPreferredName()), is(false));
     }
 
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
+
     public void testGetStateDocId() {
         OutlierDetection outlierDetection = createRandom();
         assertThat(outlierDetection.persistsState(), is(false));
@@ -14,6 +14,8 @@ import java.io.IOException;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class RegressionTests extends AbstractSerializingTestCase<Regression> {
 
@@ -66,6 +68,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
+
     public void testGetStateDocId() {
         Regression regression = createRandom();
         assertThat(regression.persistsState(), is(true));
@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.integration;
 
 import com.google.common.collect.Ordering;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.get.GetResponse;
@@ -37,10 +38,10 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
 public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
-    private static final String NUMERICAL_FEATURE_FIELD = "feature";
-    private static final String DEPENDENT_VARIABLE_FIELD = "variable";
-    private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
-    private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
+    private static final String NUMERICAL_FIELD = "numerical-field";
+    private static final String KEYWORD_FIELD = "keyword-field";
+    private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+    private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
 
     private String jobId;
     private String sourceIndex;
@@ -53,36 +54,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
         initialize("classification_single_numeric_feature_and_mixed_data_set");
+        indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
-
-        { // Index 350 rows, 300 of them being training rows.
-            client().admin().indices().prepareCreate(sourceIndex)
-                .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
-                .get();
-
-            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-            for (int i = 0; i < 300; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-                String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            for (int i = 300; i < 350; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            BulkResponse bulkResponse = bulkRequestBuilder.get();
-            if (bulkResponse.hasFailures()) {
-                fail("Failed to index data: " + bulkResponse.buildFailureMessage());
-            }
-        }
 
-        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -97,10 +71,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             Map<String, Object> destDoc = getDestDoc(config, hit);
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertThat(resultsObject.containsKey("is_training"), is(true));
-            assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
+            assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
             assertThat(resultsObject.containsKey("top_classes"), is(false));
         }
 
@@ -117,9 +91,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
         initialize("classification_only_training_data_and_training_percent_is_100");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
-        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -133,8 +107,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertThat(resultsObject.containsKey("is_training"), is(true));
             assertThat(resultsObject.get("is_training"), is(true));
             assertThat(resultsObject.containsKey("top_classes"), is(false));
@@ -153,7 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
         initialize("classification_only_training_data_and_training_percent_is_50");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
         DataFrameAnalyticsConfig config =
             buildAnalytics(
@@ -161,7 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
+                new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -176,8 +150,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
 
             assertThat(resultsObject.containsKey("is_training"), is(true));
             // Let's just assert there's both training and non-training results
@@ -205,7 +179,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
     @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
         initialize("classification_top_classes_requested");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
         int numTopClasses = 2;
         DataFrameAnalyticsConfig config =
@@ -214,7 +188,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
+                new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -229,8 +203,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             Map<String, Object> destDoc = getDestDoc(config, hit);
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertTopClasses(resultsObject, numTopClasses);
         }
 
@@ -245,25 +219,47 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             "Finished analysis");
     }
 
+    public void testDependentVariableCardinalityTooHighError() {
+        initialize("cardinality_too_high");
+        indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));
+
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(jobId));
+        assertThat(e.status().getStatus(), equalTo(400));
+        assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
+    }
+
     private void initialize(String jobId) {
         this.jobId = jobId;
         this.sourceIndex = jobId + "_source_index";
         this.destIndex = sourceIndex + "_results";
     }
 
-    private static void indexTrainingData(String sourceIndex, int numRows) {
+    private static void indexData(String sourceIndex,
+                                  int numTrainingRows, int numNonTrainingRows,
+                                  List<Double> numericalFieldValues, List<String> keywordFieldValues) {
         client().admin().indices().prepareCreate(sourceIndex)
-            .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
+            .addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
             .get();
 
         BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-        for (int i = 0; i < numRows; i++) {
-            Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-            String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
+        for (int i = 0; i < numTrainingRows; i++) {
+            Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
+            String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());
 
             IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
+                .source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
            bulkRequestBuilder.add(indexRequest);
         }
+        for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
+            Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
+
+            IndexRequest indexRequest = new IndexRequest(sourceIndex)
+                .source(NUMERICAL_FIELD, numericalValue);
+            bulkRequestBuilder.add(indexRequest);
+        }
         BulkResponse bulkResponse = bulkRequestBuilder.get();
@@ -302,10 +298,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             classNames.add((String) topClass.get("class_name"));
             classProbabilities.add((Double) topClass.get("class_probability"));
         }
-        // Assert that all the class names come from the set of dependent variable values.
-        classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
+        // Assert that all the predicted class names come from the set of keyword field values.
+        classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
         // Assert that the first class listed in top classes is the same as the predicted class.
-        assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
+        assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
         // Assert that all the class probabilities lie within [0, 1] interval.
         classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
         // Assert that the top classes are listed in the order of decreasing probabilities.
@@ -13,6 +13,9 @@ import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
@@ -22,6 +25,10 @@ import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.Cardinality;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -34,6 +41,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 public class DataFrameDataExtractorFactory {
 
@@ -172,13 +180,65 @@ public class DataFrameDataExtractorFactory {
                                           boolean isTaskRestarting,
                                           ActionListener<ExtractedFields> listener) {
         AtomicInteger docValueFieldsLimitHolder = new AtomicInteger();
+        AtomicReference<ExtractedFields> extractedFieldsHolder = new AtomicReference<>();
 
-        // Step 3. Extract fields (if possible) and notify listener
+        // Step 4. Check fields cardinality vs limits and notify listener
+        ActionListener<SearchResponse> checkCardinalityHandler = ActionListener.wrap(
+            searchResponse -> {
+                if (searchResponse != null) {
+                    Aggregations aggs = searchResponse.getAggregations();
+                    if (aggs == null) {
+                        listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
+                        return;
+                    }
+                    for (Map.Entry<String, Long> entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) {
+                        String fieldName = entry.getKey();
+                        Long limit = entry.getValue();
+                        Cardinality cardinality = aggs.get(fieldName);
+                        if (cardinality == null) {
+                            listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
+                            return;
+                        }
+                        if (cardinality.getValue() > limit) {
+                            listener.onFailure(
+                                ExceptionsHelper.badRequestException(
+                                    "Field [{}] must have at most [{}] distinct values but there were at least [{}]",
+                                    fieldName, limit, cardinality.getValue()));
+                            return;
+                        }
+                    }
+                }
+                listener.onResponse(extractedFieldsHolder.get());
+            },
+            listener::onFailure
+        );
+
+        // Step 3. Extract fields (if possible)
         ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesHandler = ActionListener.wrap(
-            fieldCapabilitiesResponse -> listener.onResponse(
-                new ExtractedFieldsDetector(
-                    index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
-                    .detect()),
+            fieldCapabilitiesResponse -> {
+                extractedFieldsHolder.set(
+                    new ExtractedFieldsDetector(
+                        index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
+                        .detect());
+
+                Map<String, Long> fieldCardinalityLimits = config.getAnalysis().getFieldCardinalityLimits();
+                if (fieldCardinalityLimits.isEmpty()) {
+                    checkCardinalityHandler.onResponse(null);
+                } else {
+                    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0);
+                    for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) {
+                        String fieldName = entry.getKey();
+                        Long limit = entry.getValue();
+                        searchSourceBuilder.aggregation(
+                            AggregationBuilders.cardinality(fieldName)
+                                .field(fieldName)
+                                .precisionThreshold(limit + 1));
+                    }
+                    SearchRequest searchRequest = new SearchRequest(config.getSource().getIndex()).source(searchSourceBuilder);
+                    ClientHelper.executeWithHeadersAsync(
+                        config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE, searchRequest, checkCardinalityHandler);
+                }
+            },
             listener::onFailure
         );
 
@@ -69,7 +69,7 @@
         body:
           mappings:
             properties:
-              long_field: { "type": "long" }
+              long_field: { type: "long" }
 
   - do:
       ml.put_data_frame_analytics:
@@ -140,3 +140,52 @@
       catch: /dest index \[non-empty-dest\] must be empty/
       ml.start_data_frame_analytics:
         id: "start_given_empty_dest_index"
+
+---
+"Test start classification analysis when the dependent variable cardinality is too high":
+  - do:
+      indices.create:
+        index: index-with-dep-var-with-too-high-card
+        body:
+          mappings:
+            properties:
+              numeric_field: { type: "long" }
+              keyword_field: { type: "keyword" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 1.0, keyword_field: "class_a" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 2.0, keyword_field: "class_b" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 3.0, keyword_field: "class_c" }
+
+  - do:
+      indices.refresh:
+        index: index-with-dep-var-with-too-high-card
+
+  - do:
+      ml.put_data_frame_analytics:
+        id: "too-high-card"
+        body: >
+          {
+            "source": {
+              "index": "index-with-dep-var-with-too-high-card"
+            },
+            "dest": {
+              "index": "index-with-dep-var-with-too-high-card-dest"
+            },
+            "analysis": { "classification": { "dependent_variable": "keyword_field" } }
+          }
+
+  - do:
+      catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
+      ml.start_data_frame_analytics:
+        id: "too-high-card"