Require that the dependent variable column has at most 2 distinct values in classfication analysis. (#47858) (#47906)

This commit is contained in:
Przemysław Witek 2019-10-11 14:57:08 +02:00 committed by GitHub
parent a0d0866f59
commit c62fe8c344
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 208 additions and 63 deletions

View File

@ -152,6 +152,12 @@ public class Classification implements DataFrameAnalysis {
return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
}
@Override
public Map<String, Long> getFieldCardinalityLimits() {
// This restriction is due to the fact that currently the C++ backend only supports binomial classification.
return Collections.singletonMap(dependentVariable, 2L);
}
@Override
public boolean supportsMissingValues() {
return true;

View File

@ -28,6 +28,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
*/
List<RequiredField> getRequiredFields();
/**
* @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
*/
Map<String, Long> getFieldCardinalityLimits();
/**
* @return {@code true} if this analysis supports data frame rows with missing values
*/

View File

@ -218,6 +218,11 @@ public class OutlierDetection implements DataFrameAnalysis {
return Collections.emptyList();
}
@Override
public Map<String, Long> getFieldCardinalityLimits() {
return Collections.emptyMap();
}
@Override
public boolean supportsMissingValues() {
return false;

View File

@ -139,6 +139,11 @@ public class Regression implements DataFrameAnalysis {
return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
}
@Override
public Map<String, Long> getFieldCardinalityLimits() {
return Collections.emptyMap();
}
@Override
public boolean supportsMissingValues() {
return true;

View File

@ -13,6 +13,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
@ -65,4 +68,8 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
}

View File

@ -15,6 +15,8 @@ import java.util.Map;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
@ -82,6 +84,10 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
assertThat(params.get(OutlierDetection.STANDARDIZATION_ENABLED.getPreferredName()), is(false));
}
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
public void testGetStateDocId() {
OutlierDetection outlierDetection = createRandom();
assertThat(outlierDetection.persistsState(), is(false));

View File

@ -14,6 +14,8 @@ import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
@ -66,6 +68,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testFieldCardinalityLimitsIsNonNull() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
}
public void testGetStateDocId() {
Regression regression = createRandom();
assertThat(regression.persistsState(), is(true));

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.integration;
import com.google.common.collect.Ordering;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
@ -37,10 +38,10 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FEATURE_FIELD = "feature";
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
private static final String NUMERICAL_FIELD = "numerical-field";
private static final String KEYWORD_FIELD = "keyword-field";
private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
private String jobId;
private String sourceIndex;
@ -53,36 +54,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("classification_single_numeric_feature_and_mixed_data_set");
indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
{ // Index 350 rows, 300 of them being training rows.
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
.get();
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < 300; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
bulkRequestBuilder.add(indexRequest);
}
for (int i = 300; i < 350; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
}
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
putAnalytics(config);
@ -97,10 +71,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertThat(resultsObject.containsKey("top_classes"), is(false));
}
@ -117,9 +91,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_100");
indexTrainingData(sourceIndex, 300);
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
putAnalytics(config);
@ -133,8 +107,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
assertThat(resultsObject.containsKey("top_classes"), is(false));
@ -153,7 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
initialize("classification_only_training_data_and_training_percent_is_50");
indexTrainingData(sourceIndex, 300);
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
DataFrameAnalyticsConfig config =
buildAnalytics(
@ -161,7 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
registerAnalytics(config);
putAnalytics(config);
@ -176,8 +150,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there's both training and non-training results
@ -205,7 +179,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
initialize("classification_top_classes_requested");
indexTrainingData(sourceIndex, 300);
indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
int numTopClasses = 2;
DataFrameAnalyticsConfig config =
@ -214,7 +188,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
sourceIndex,
destIndex,
null,
new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
registerAnalytics(config);
putAnalytics(config);
@ -229,8 +203,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
assertTopClasses(resultsObject, numTopClasses);
}
@ -245,25 +219,47 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Finished analysis");
}
public void testDependentVariableCardinalityTooHighError() {
initialize("cardinality_too_high");
indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
registerAnalytics(config);
putAnalytics(config);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(jobId));
assertThat(e.status().getStatus(), equalTo(400));
assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
}
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";
this.destIndex = sourceIndex + "_results";
}
private static void indexTrainingData(String sourceIndex, int numRows) {
private static void indexData(String sourceIndex,
int numTrainingRows, int numNonTrainingRows,
List<Double> numericalFieldValues, List<String> keywordFieldValues) {
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
.addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
.get();
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
for (int i = 0; i < numRows; i++) {
Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
for (int i = 0; i < numTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
.source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
bulkRequestBuilder.add(indexRequest);
}
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
IndexRequest indexRequest = new IndexRequest(sourceIndex)
.source(NUMERICAL_FIELD, numericalValue);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
@ -302,10 +298,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
classNames.add((String) topClass.get("class_name"));
classProbabilities.add((Double) topClass.get("class_probability"));
}
// Assert that all the class names come from the set of dependent variable values.
classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
// Assert that all the predicted class names come from the set of keyword field values.
classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
// Assert that the first class listed in top classes is the same as the predicted class.
assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
// Assert that all the class probabilities lie within [0, 1] interval.
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
// Assert that the top classes are listed in the order of decreasing probabilities.

View File

@ -13,6 +13,9 @@ import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.collect.ImmutableOpenMap;
@ -22,6 +25,10 @@ import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -34,6 +41,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
public class DataFrameDataExtractorFactory {
@ -172,13 +180,65 @@ public class DataFrameDataExtractorFactory {
boolean isTaskRestarting,
ActionListener<ExtractedFields> listener) {
AtomicInteger docValueFieldsLimitHolder = new AtomicInteger();
AtomicReference<ExtractedFields> extractedFieldsHolder = new AtomicReference<>();
// Step 3. Extract fields (if possible) and notify listener
// Step 4. Check fields cardinality vs limits and notify listener
ActionListener<SearchResponse> checkCardinalityHandler = ActionListener.wrap(
searchResponse -> {
if (searchResponse != null) {
Aggregations aggs = searchResponse.getAggregations();
if (aggs == null) {
listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
return;
}
for (Map.Entry<String, Long> entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) {
String fieldName = entry.getKey();
Long limit = entry.getValue();
Cardinality cardinality = aggs.get(fieldName);
if (cardinality == null) {
listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
return;
}
if (cardinality.getValue() > limit) {
listener.onFailure(
ExceptionsHelper.badRequestException(
"Field [{}] must have at most [{}] distinct values but there were at least [{}]",
fieldName, limit, cardinality.getValue()));
return;
}
}
}
listener.onResponse(extractedFieldsHolder.get());
},
listener::onFailure
);
// Step 3. Extract fields (if possible)
ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesHandler = ActionListener.wrap(
fieldCapabilitiesResponse -> listener.onResponse(
new ExtractedFieldsDetector(
index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
.detect()),
fieldCapabilitiesResponse -> {
extractedFieldsHolder.set(
new ExtractedFieldsDetector(
index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
.detect());
Map<String, Long> fieldCardinalityLimits = config.getAnalysis().getFieldCardinalityLimits();
if (fieldCardinalityLimits.isEmpty()) {
checkCardinalityHandler.onResponse(null);
} else {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0);
for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) {
String fieldName = entry.getKey();
Long limit = entry.getValue();
searchSourceBuilder.aggregation(
AggregationBuilders.cardinality(fieldName)
.field(fieldName)
.precisionThreshold(limit + 1));
}
SearchRequest searchRequest = new SearchRequest(config.getSource().getIndex()).source(searchSourceBuilder);
ClientHelper.executeWithHeadersAsync(
config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE, searchRequest, checkCardinalityHandler);
}
},
listener::onFailure
);

View File

@ -69,7 +69,7 @@
body:
mappings:
properties:
long_field: { "type": "long" }
long_field: { type: "long" }
- do:
ml.put_data_frame_analytics:
@ -140,3 +140,52 @@
catch: /dest index \[non-empty-dest\] must be empty/
ml.start_data_frame_analytics:
id: "start_given_empty_dest_index"
---
"Test start classification analysis when the dependent variable cardinality is too high":
- do:
indices.create:
index: index-with-dep-var-with-too-high-card
body:
mappings:
properties:
numeric_field: { type: "long" }
keyword_field: { type: "keyword" }
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 1.0, keyword_field: "class_a" }
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 2.0, keyword_field: "class_b" }
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 3.0, keyword_field: "class_c" }
- do:
indices.refresh:
index: index-with-dep-var-with-too-high-card
- do:
ml.put_data_frame_analytics:
id: "too-high-card"
body: >
{
"source": {
"index": "index-with-dep-var-with-too-high-card"
},
"dest": {
"index": "index-with-dep-var-with-too-high-card-dest"
},
"analysis": { "classification": { "dependent_variable": "keyword_field" } }
}
- do:
catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
ml.start_data_frame_analytics:
id: "too-high-card"