[ML] Use query in cardinality check (#49939) (#49984)

When checking the cardinality of a field, the query should be taken into account. The user might know about some bad data in their index and want to filter down to the target_field values they care about.
This commit is contained in:
Benjamin Trent 2019-12-09 10:14:41 -05:00 committed by GitHub
parent 056c698540
commit 0b6ce9683c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 4 deletions

View File

@ -15,6 +15,8 @@ import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@ -228,7 +230,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction");
}
public void testDependentVariableCardinalityTooHighError() {
public void testDependentVariableCardinalityTooHighError() throws Exception {
initialize("cardinality_too_high");
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
// Index one more document with a class different than the two already used.
@ -245,6 +247,27 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
}
/**
 * Checks that an analytics job starts and runs to completion when the source index contains
 * one more dependent-variable class than the sibling test
 * testDependentVariableCardinalityTooHighError tolerates, as long as the job's source query
 * restricts the documents to the values listed in KEYWORD_FIELD_VALUES.
 *
 * @throws Exception propagated from the helpers used to build and run the job
 *                   (buildAnalytics is declared to throw Exception)
 */
public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRange() throws Exception {
initialize("cardinality_too_high_with_query");
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
// Index one more document with a class different than the two already used.
client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex)
.source(KEYWORD_FIELD, "fox")
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE))
.actionGet();
// Keep only documents whose keyword field is one of KEYWORD_FIELD_VALUES.
// NOTE(review): assumes "fox" is not in KEYWORD_FIELD_VALUES, so the extra document is filtered out; confirm.
QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.termsQuery(KEYWORD_FIELD, KEYWORD_FIELD_VALUES));
// resultsField is passed as null; the query is attached to the job's source configuration by buildAnalytics.
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD), query);
registerAnalytics(config);
putAnalytics(config);
// Should not throw
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
// All four expected values are 100; presumably one per reported progress phase (confirm against assertProgress).
assertProgress(jobId, 100, 100, 100, 100);
}
private void initialize(String jobId) {
this.jobId = jobId;
this.sourceIndex = jobId + "_source_index";

View File

@ -16,6 +16,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
@ -37,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
@ -161,10 +163,16 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
}
protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
@Nullable String resultsField, DataFrameAnalysis analysis) {
@Nullable String resultsField, DataFrameAnalysis analysis) throws Exception {
return buildAnalytics(id, sourceIndex, destIndex, resultsField, analysis, QueryBuilders.matchAllQuery());
}
protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
@Nullable String resultsField, DataFrameAnalysis analysis,
QueryBuilder queryBuilder) throws Exception {
return new DataFrameAnalyticsConfig.Builder()
.setId(id)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null))
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, QueryProvider.fromParsedQuery(queryBuilder), null))
.setDest(new DataFrameAnalyticsDest(destIndex, resultsField))
.setAnalysis(analysis)
.build();

View File

@ -109,7 +109,7 @@ public class ExtractedFieldsDetectorFactory {
listener::onFailure
);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(config.getSource().getParsedQuery());
for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) {
String fieldName = entry.getKey();
Long limit = entry.getValue();