Require that the dependent variable column has at most 2 distinct values in classification analysis. (#47858) (#47906)
parent a0d0866f59
commit c62fe8c344
@@ -152,6 +152,12 @@ public class Classification implements DataFrameAnalysis {
         return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        // This restriction is due to the fact that currently the C++ backend only supports binomial classification.
+        return Collections.singletonMap(dependentVariable, 2L);
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return true;
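A minimal usage sketch (not part of this commit) of the limit that Classification now declares. It assumes the single-argument Classification(String dependentVariable) constructor that the integration tests further down use; the field name "animal_class" is hypothetical, and the assertion style mirrors the Hamcrest tests in this change:

    // Sketch only, not from this commit. Assumes the Classification class changed above
    // and a made-up dependent variable named "animal_class".
    public void assertClassificationDeclaresBinaryLimit() {
        Classification classification = new Classification("animal_class");
        Map<String, Long> limits = classification.getFieldCardinalityLimits();
        // Only the dependent variable is constrained, and the limit is 2 because the
        // C++ backend currently supports only binomial classification.
        assertThat(limits, equalTo(Collections.singletonMap("animal_class", 2L)));
    }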
@@ -28,6 +28,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
      */
     List<RequiredField> getRequiredFields();
 
+    /**
+     * @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
+     */
+    Map<String, Long> getFieldCardinalityLimits();
+
     /**
      * @return {@code true} if this analysis supports data frame rows with missing values
      */
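For orientation, a simplified, synchronous sketch of how a consumer of the new interface method can enforce the limits. The real enforcement added by this commit is asynchronous and lives in DataFrameDataExtractorFactory further down; observedDistinctCounts here is a hypothetical input, not something the interface provides:

    // Sketch only: a plain-Java consumer of DataFrameAnalysis#getFieldCardinalityLimits().
    static void checkFieldCardinalityLimits(DataFrameAnalysis analysis, Map<String, Long> observedDistinctCounts) {
        for (Map.Entry<String, Long> entry : analysis.getFieldCardinalityLimits().entrySet()) {
            String fieldName = entry.getKey();
            long limit = entry.getValue();
            long observed = observedDistinctCounts.getOrDefault(fieldName, 0L);
            if (observed > limit) {
                // Same wording as the validation error introduced by this change.
                throw new IllegalArgumentException("Field [" + fieldName + "] must have at most [" + limit
                    + "] distinct values but there were at least [" + observed + "]");
            }
        }
    }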
@@ -218,6 +218,11 @@ public class OutlierDetection implements DataFrameAnalysis {
         return Collections.emptyList();
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        return Collections.emptyMap();
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return false;
@@ -139,6 +139,11 @@ public class Regression implements DataFrameAnalysis {
         return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        return Collections.emptyMap();
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return true;
@@ -13,6 +13,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
 import java.io.IOException;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
 
@@ -65,4 +68,8 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
+
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
 }
@@ -15,6 +15,8 @@ import java.util.Map;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
 
@@ -82,6 +84,10 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
         assertThat(params.get(OutlierDetection.STANDARDIZATION_ENABLED.getPreferredName()), is(false));
     }
 
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
+
     public void testGetStateDocId() {
         OutlierDetection outlierDetection = createRandom();
         assertThat(outlierDetection.persistsState(), is(false));
@@ -14,6 +14,8 @@ import java.io.IOException;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class RegressionTests extends AbstractSerializingTestCase<Regression> {
 
@@ -66,6 +68,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
+
     public void testGetStateDocId() {
         Regression regression = createRandom();
         assertThat(regression.persistsState(), is(true));
@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.integration;
 
 import com.google.common.collect.Ordering;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.get.GetResponse;
@@ -37,10 +38,10 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
 public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
-    private static final String NUMERICAL_FEATURE_FIELD = "feature";
-    private static final String DEPENDENT_VARIABLE_FIELD = "variable";
-    private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
-    private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
+    private static final String NUMERICAL_FIELD = "numerical-field";
+    private static final String KEYWORD_FIELD = "keyword-field";
+    private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+    private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
 
     private String jobId;
     private String sourceIndex;
@@ -53,36 +54,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
         initialize("classification_single_numeric_feature_and_mixed_data_set");
+        indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
-        { // Index 350 rows, 300 of them being training rows.
-            client().admin().indices().prepareCreate(sourceIndex)
-                .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
-                .get();
-
-            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-            for (int i = 0; i < 300; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-                String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            for (int i = 300; i < 350; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            BulkResponse bulkResponse = bulkRequestBuilder.get();
-            if (bulkResponse.hasFailures()) {
-                fail("Failed to index data: " + bulkResponse.buildFailureMessage());
-            }
-        }
-
-        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -97,10 +71,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             Map<String, Object> destDoc = getDestDoc(config, hit);
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertThat(resultsObject.containsKey("is_training"), is(true));
-            assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
+            assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
             assertThat(resultsObject.containsKey("top_classes"), is(false));
         }
 
@@ -117,9 +91,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
         initialize("classification_only_training_data_and_training_percent_is_100");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
-        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -133,8 +107,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertThat(resultsObject.containsKey("is_training"), is(true));
             assertThat(resultsObject.get("is_training"), is(true));
             assertThat(resultsObject.containsKey("top_classes"), is(false));
@@ -153,7 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
         initialize("classification_only_training_data_and_training_percent_is_50");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
         DataFrameAnalyticsConfig config =
             buildAnalytics(
@@ -161,7 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
+                new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -176,8 +150,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
 
             assertThat(resultsObject.containsKey("is_training"), is(true));
             // Let's just assert there's both training and non-training results
@@ -205,7 +179,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
     @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
         initialize("classification_top_classes_requested");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
         int numTopClasses = 2;
         DataFrameAnalyticsConfig config =
@@ -214,7 +188,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
+                new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -229,8 +203,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             Map<String, Object> destDoc = getDestDoc(config, hit);
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertTopClasses(resultsObject, numTopClasses);
         }
 
@@ -245,25 +219,47 @@
             "Finished analysis");
     }
 
+    public void testDependentVariableCardinalityTooHighError() {
+        initialize("cardinality_too_high");
+        indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));
+
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(jobId));
+        assertThat(e.status().getStatus(), equalTo(400));
+        assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
+    }
+
     private void initialize(String jobId) {
         this.jobId = jobId;
         this.sourceIndex = jobId + "_source_index";
         this.destIndex = sourceIndex + "_results";
     }
 
-    private static void indexTrainingData(String sourceIndex, int numRows) {
+    private static void indexData(String sourceIndex,
+                                  int numTrainingRows, int numNonTrainingRows,
+                                  List<Double> numericalFieldValues, List<String> keywordFieldValues) {
         client().admin().indices().prepareCreate(sourceIndex)
-            .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
+            .addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
             .get();
 
         BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-        for (int i = 0; i < numRows; i++) {
-            Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-            String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
+        for (int i = 0; i < numTrainingRows; i++) {
+            Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
+            String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());
 
             IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
+                .source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
+            bulkRequestBuilder.add(indexRequest);
+        }
+        for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
+            Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
+
+            IndexRequest indexRequest = new IndexRequest(sourceIndex)
+                .source(NUMERICAL_FIELD, numericalValue);
             bulkRequestBuilder.add(indexRequest);
         }
         BulkResponse bulkResponse = bulkRequestBuilder.get();
@@ -302,10 +298,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             classNames.add((String) topClass.get("class_name"));
             classProbabilities.add((Double) topClass.get("class_probability"));
         }
-        // Assert that all the class names come from the set of dependent variable values.
-        classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
+        // Assert that all the predicted class names come from the set of keyword field values.
+        classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
         // Assert that the first class listed in top classes is the same as the predicted class.
-        assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
+        assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
         // Assert that all the class probabilities lie within [0, 1] interval.
         classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
        // Assert that the top classes are listed in the order of decreasing probabilities.
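The indexTrainingData helper is generalized into indexData, so the test data setups in this file reduce to three call shapes; a quick recap of what the diff above already shows:

    // Mixed data set: 300 training rows (both fields present) plus 50 non-training rows
    // that are missing the dependent variable.
    indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);

    // Training-only data sets used by the training-percent tests.
    indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);

    // Dependent variable with three distinct classes, deliberately over the limit of 2,
    // so starting the analytics job is expected to fail with a 400.
    indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));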
@@ -13,6 +13,9 @@ import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
@@ -22,6 +25,10 @@ import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.Cardinality;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -34,6 +41,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 public class DataFrameDataExtractorFactory {
 
@@ -172,13 +180,65 @@ public class DataFrameDataExtractorFactory {
                                            boolean isTaskRestarting,
                                            ActionListener<ExtractedFields> listener) {
         AtomicInteger docValueFieldsLimitHolder = new AtomicInteger();
+        AtomicReference<ExtractedFields> extractedFieldsHolder = new AtomicReference<>();
 
-        // Step 3. Extract fields (if possible) and notify listener
+        // Step 4. Check fields cardinality vs limits and notify listener
+        ActionListener<SearchResponse> checkCardinalityHandler = ActionListener.wrap(
+            searchResponse -> {
+                if (searchResponse != null) {
+                    Aggregations aggs = searchResponse.getAggregations();
+                    if (aggs == null) {
+                        listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
+                        return;
+                    }
+                    for (Map.Entry<String, Long> entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) {
+                        String fieldName = entry.getKey();
+                        Long limit = entry.getValue();
+                        Cardinality cardinality = aggs.get(fieldName);
+                        if (cardinality == null) {
+                            listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
+                            return;
+                        }
+                        if (cardinality.getValue() > limit) {
+                            listener.onFailure(
+                                ExceptionsHelper.badRequestException(
+                                    "Field [{}] must have at most [{}] distinct values but there were at least [{}]",
+                                    fieldName, limit, cardinality.getValue()));
+                            return;
+                        }
+                    }
+                }
+                listener.onResponse(extractedFieldsHolder.get());
+            },
+            listener::onFailure
+        );
 
+        // Step 3. Extract fields (if possible)
         ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesHandler = ActionListener.wrap(
-            fieldCapabilitiesResponse -> listener.onResponse(
-                new ExtractedFieldsDetector(
-                    index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
-                    .detect()),
+            fieldCapabilitiesResponse -> {
+                extractedFieldsHolder.set(
+                    new ExtractedFieldsDetector(
+                        index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
+                        .detect());
+
+                Map<String, Long> fieldCardinalityLimits = config.getAnalysis().getFieldCardinalityLimits();
+                if (fieldCardinalityLimits.isEmpty()) {
+                    checkCardinalityHandler.onResponse(null);
+                } else {
+                    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0);
+                    for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) {
+                        String fieldName = entry.getKey();
+                        Long limit = entry.getValue();
+                        searchSourceBuilder.aggregation(
+                            AggregationBuilders.cardinality(fieldName)
+                                .field(fieldName)
+                                .precisionThreshold(limit + 1));
+                    }
+                    SearchRequest searchRequest = new SearchRequest(config.getSource().getIndex()).source(searchSourceBuilder);
+                    ClientHelper.executeWithHeadersAsync(
+                        config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE, searchRequest, checkCardinalityHandler);
+                }
+            },
             listener::onFailure
         );
 
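A note on the precisionThreshold(limit + 1) choice above: the cardinality aggregation is approximate, but counts are expected to be accurate up to the configured precision threshold, so setting it just past the limit is enough to decide whether a field exceeds the limit. That is also why the error message says "at least". Below is a standalone sketch of the probe, pulled out of the async listener chain; the method and variable names are illustrative, not from the commit:

    // Sketch only: build a hits-free search whose single cardinality aggregation answers
    // "does this field have more than `limit` distinct values?"
    static SearchRequest buildCardinalityProbe(String[] sourceIndices, String fieldName, long limit) {
        SearchSourceBuilder source = new SearchSourceBuilder()
            .size(0)  // no hits needed, only the aggregation
            .aggregation(
                AggregationBuilders.cardinality(fieldName)
                    .field(fieldName)
                    // Accurate counting is only needed up to limit + 1; anything beyond that
                    // is reported as a violation anyway.
                    .precisionThreshold(limit + 1));
        return new SearchRequest(sourceIndices).source(source);
    }

    // Reading the response: a value greater than the limit means the source index held
    // at least limit + 1 distinct values for that field.
    static boolean exceedsLimit(Cardinality cardinality, long limit) {
        return cardinality.getValue() > limit;
    }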
@@ -69,7 +69,7 @@
       body:
         mappings:
           properties:
-            long_field: { "type": "long" }
+            long_field: { type: "long" }
 
   - do:
       ml.put_data_frame_analytics:
@@ -140,3 +140,52 @@
       catch: /dest index \[non-empty-dest\] must be empty/
       ml.start_data_frame_analytics:
         id: "start_given_empty_dest_index"
+
+---
+"Test start classification analysis when the dependent variable cardinality is too high":
+  - do:
+      indices.create:
+        index: index-with-dep-var-with-too-high-card
+        body:
+          mappings:
+            properties:
+              numeric_field: { type: "long" }
+              keyword_field: { type: "keyword" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 1.0, keyword_field: "class_a" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 2.0, keyword_field: "class_b" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 3.0, keyword_field: "class_c" }
+
+  - do:
+      indices.refresh:
+        index: index-with-dep-var-with-too-high-card
+
+  - do:
+      ml.put_data_frame_analytics:
+        id: "too-high-card"
+        body: >
+          {
+            "source": {
+              "index": "index-with-dep-var-with-too-high-card"
+            },
+            "dest": {
+              "index": "index-with-dep-var-with-too-high-card-dest"
+            },
+            "analysis": { "classification": { "dependent_variable": "keyword_field" } }
+          }
+
+  - do:
+      catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
+      ml.start_data_frame_analytics:
+        id: "too-high-card"