[7.x][ML] Validate classification dependent_variable cardinality is at least 2 (#51232) (#51309)

Data frame analytics classification currently only supports 2 classes for the
dependent variable. We were checking that the field's cardinality is not higher
than 2, but we should also check that it is not less than 2, as otherwise the
process fails.

Backport of #51232
This commit is contained in:
Dimitris Athanasiou 2020-01-22 16:51:16 +02:00 committed by GitHub
parent 2a73e849d6
commit 59687a9384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 197 additions and 70 deletions

View File

@ -245,9 +245,9 @@ public class Classification implements DataFrameAnalysis {
}
@Override
public Map<String, Long> getFieldCardinalityLimits() {
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
// This restriction is due to the fact that currently the C++ backend only supports binomial classification.
return Collections.singletonMap(dependentVariable, 2L);
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
}
@SuppressWarnings("unchecked")

View File

@ -37,9 +37,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
List<RequiredField> getRequiredFields();
/**
* @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
* @return {@link List} containing cardinality constraints for the selected (analysis-specific) fields
*/
Map<String, Long> getFieldCardinalityLimits();
List<FieldCardinalityConstraint> getFieldCardinalityConstraints();
/**
* Returns fields for which the mappings should be either predefined or copied from source index to destination index.

View File

@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.Objects;
/**
 * Allows checking a field's cardinality against given lower and upper bounds.
 * Both bounds are inclusive.
 */
public class FieldCardinalityConstraint {

    private final String field;
    private final long lowerBound;
    private final long upperBound;

    /**
     * Creates a constraint requiring the cardinality of {@code field} to lie
     * within {@code [lowerBound, upperBound]} (both inclusive).
     *
     * @param field the name of the field the constraint applies to; must not be null
     * @param lowerBound the minimum allowed number of distinct values (inclusive)
     * @param upperBound the maximum allowed number of distinct values (inclusive)
     * @throws IllegalArgumentException if {@code lowerBound} is greater than {@code upperBound},
     *         which would yield a constraint no cardinality could ever satisfy
     */
    public static FieldCardinalityConstraint between(String field, long lowerBound, long upperBound) {
        return new FieldCardinalityConstraint(field, lowerBound, upperBound);
    }

    private FieldCardinalityConstraint(String field, long lowerBound, long upperBound) {
        this.field = Objects.requireNonNull(field);
        // Guard against an unsatisfiable constraint; previously such a constraint
        // would be accepted silently and reject every possible cardinality.
        if (lowerBound > upperBound) {
            throw new IllegalArgumentException(
                "lower bound [" + lowerBound + "] must be less than or equal to upper bound [" + upperBound + "]");
        }
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
    }

    public String getField() {
        return field;
    }

    public long getLowerBound() {
        return lowerBound;
    }

    public long getUpperBound() {
        return upperBound;
    }

    /**
     * Checks the given cardinality against this constraint's bounds.
     *
     * @param fieldCardinality the observed number of distinct values for {@link #getField()}
     * @throws org.elasticsearch.ElasticsearchStatusException with a 400 (bad request) status
     *         if the cardinality falls outside {@code [lowerBound, upperBound]}
     */
    public void check(long fieldCardinality) {
        if (fieldCardinality < lowerBound) {
            throw ExceptionsHelper.badRequestException(
                "Field [{}] must have at least [{}] distinct values but there were [{}]",
                field, lowerBound, fieldCardinality);
        }
        if (fieldCardinality > upperBound) {
            throw ExceptionsHelper.badRequestException(
                "Field [{}] must have at most [{}] distinct values but there were at least [{}]",
                field, upperBound, fieldCardinality);
        }
    }
}

View File

@ -225,8 +225,8 @@ public class OutlierDetection implements DataFrameAnalysis {
}
@Override
public Map<String, Long> getFieldCardinalityLimits() {
return Collections.emptyMap();
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
return Collections.emptyList();
}
@Override

View File

@ -182,8 +182,8 @@ public class Regression implements DataFrameAnalysis {
}
@Override
public Map<String, Long> getFieldCardinalityLimits() {
return Collections.emptyMap();
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
return Collections.emptyList();
}
@Override

View File

@ -22,6 +22,7 @@ import org.hamcrest.Matchers;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -169,7 +170,13 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
}
public void testFieldCardinalityLimitsIsNonEmpty() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
Classification classification = createTestInstance();
List<FieldCardinalityConstraint> constraints = classification.getFieldCardinalityConstraints();
assertThat(constraints.size(), equalTo(1));
assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
}
public void testGetExplicitlyMappedFields() {

View File

@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.equalTo;
public class FieldCardinalityConstraintTests extends ESTestCase {

    // A cardinality anywhere within [lowerBound, upperBound] must pass, bounds included.
    public void testBetween_GivenWithinLimits() {
        FieldCardinalityConstraint fooConstraint = FieldCardinalityConstraint.between("foo", 3, 6);
        for (long cardinality = 3; cardinality <= 6; cardinality++) {
            fooConstraint.check(cardinality);
        }
    }

    // Falling short of the lower bound must raise a 400 with the "at least" message.
    public void testBetween_GivenLessThanLowerBound() {
        FieldCardinalityConstraint fooConstraint = FieldCardinalityConstraint.between("foo", 3, 6);

        ElasticsearchStatusException thrown = expectThrows(ElasticsearchStatusException.class, () -> fooConstraint.check(2L));

        assertThat(thrown.getMessage(), equalTo("Field [foo] must have at least [3] distinct values but there were [2]"));
        assertThat(thrown.status(), equalTo(RestStatus.BAD_REQUEST));
    }

    // Exceeding the upper bound must raise a 400 with the "at most" message.
    public void testBetween_GivenGreaterThanUpperBound() {
        FieldCardinalityConstraint fooConstraint = FieldCardinalityConstraint.between("foo", 3, 6);

        ElasticsearchStatusException thrown = expectThrows(ElasticsearchStatusException.class, () -> fooConstraint.check(7L));

        assertThat(thrown.getMessage(), equalTo("Field [foo] must have at most [6] distinct values but there were at least [7]"));
        assertThat(thrown.status(), equalTo(RestStatus.BAD_REQUEST));
    }
}

View File

@ -89,7 +89,7 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
}
public void testFieldCardinalityLimitsIsEmpty() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
}
public void testGetExplicitlyMappedFields() {

View File

@ -19,7 +19,6 @@ import java.io.IOException;
import java.util.Collections;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
@ -107,7 +106,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
}
public void testFieldCardinalityLimitsIsEmpty() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
}
public void testGetExplicitlyMappedFields() {

View File

@ -43,7 +43,11 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
String sourceIndex = "test-source-query-is-applied";
client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical", "type=keyword")
.addMapping("_doc",
"numeric_1", "type=double",
"numeric_2", "type=float",
"categorical", "type=keyword",
"filtered_field", "type=keyword")
.get();
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
@ -51,9 +55,11 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
for (int i = 0; i < 30; i++) {
IndexRequest indexRequest = new IndexRequest(sourceIndex);
// We insert one odd value out of 5 for one feature
indexRequest.source("numeric_1", 1.0, "numeric_2", 2.0, "categorical", i == 0 ? "only-one" : "normal");
indexRequest.source(
"numeric_1", 1.0,
"numeric_2", 2.0,
"categorical", i % 2 == 0 ? "class_1" : "class_2",
"filtered_field", i < 2 ? "bingo" : "rest"); // We tag bingo on the first two docs to ensure we have 2 classes
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
@ -66,7 +72,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId(id)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("categorical", "only-one")),
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("filtered_field", "bingo")),
null))
.setAnalysis(new Classification("categorical"))
.buildForExplain();

View File

@ -15,6 +15,9 @@ import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest;
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
import org.elasticsearch.action.admin.indices.get.GetIndexResponse;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
@ -42,6 +45,7 @@ import java.util.Objects;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeWithHeadersAsync;
public class DataFrameAnalyticsManager {
@ -158,7 +162,7 @@ public class DataFrameAnalyticsManager {
// Reindexing is complete; start analytics
ActionListener<BulkByScrollResponse> reindexCompletedListener = ActionListener.wrap(
refreshResponse -> {
reindexResponse -> {
if (task.isStopping()) {
LOGGER.debug("[{}] Stopping before starting analytics process", config.getId());
return;
@ -177,6 +181,7 @@ public class DataFrameAnalyticsManager {
ActionListener<CreateIndexResponse> copyIndexCreatedListener = ActionListener.wrap(
createIndexResponse -> {
ReindexRequest reindexRequest = new ReindexRequest();
reindexRequest.setRefresh(true);
reindexRequest.setSourceIndices(config.getSource().getIndex());
reindexRequest.setSourceQuery(config.getSource().getParsedQuery());
reindexRequest.getSearchRequest().source().fetchSource(config.getSource().getSourceFiltering());
@ -224,9 +229,6 @@ public class DataFrameAnalyticsManager {
}
private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) {
// Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
task.setReindexingFinished();
// Update state to ANALYZING and start process
ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
dataExtractorFactory -> {
@ -246,10 +248,23 @@ public class DataFrameAnalyticsManager {
error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
);
// TODO This could fail with errors. In that case we get stuck with the copied index.
// We could delete the index in case of failure or we could try building the factory before reindexing
// to catch the error early on.
DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
refreshResponse -> {
// Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
task.setReindexingFinished();
// TODO This could fail with errors. In that case we get stuck with the copied index.
// We could delete the index in case of failure or we could try building the factory before reindexing
// to catch the error early on.
DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
},
dataExtractorFactoryListener::onFailure
);
// First we need to refresh the dest index to ensure data is searchable in case the job
// was stopped after reindexing was complete but before the index was refreshed.
executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, RefreshAction.INSTANCE,
new RefreshRequest(config.getDest().getIndex()), refreshListener);
}
public void stop(DataFrameAnalyticsTask task) {

View File

@ -19,6 +19,7 @@ import org.elasticsearch.index.mapper.ObjectMapper;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
@ -284,15 +285,8 @@ public class ExtractedFieldsDetector {
}
private void checkFieldsWithCardinalityLimit() {
for (Map.Entry<String, Long> entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) {
String fieldName = entry.getKey();
long limit = entry.getValue();
long cardinality = fieldCardinalities.get(fieldName);
if (cardinality > limit) {
throw ExceptionsHelper.badRequestException(
"Field [{}] must have at most [{}] distinct values but there were at least [{}]",
fieldName, limit, cardinality);
}
for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
constraint.check(fieldCardinalities.get(constraint.getField()));
}
}

View File

@ -28,11 +28,13 @@ import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
@ -72,11 +74,11 @@ public class ExtractedFieldsDetectorFactory {
listener::onFailure
);
// Step 3. Get cardinalities for fields with limits
// Step 3. Get cardinalities for fields with constraints
ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesHandler = ActionListener.wrap(
fieldCapabilitiesResponse -> {
fieldCapsResponseHolder.set(fieldCapabilitiesResponse);
getCardinalitiesForFieldsWithLimit(index, config, fieldCardinalitiesHandler);
getCardinalitiesForFieldsWithConstraints(index, config, fieldCardinalitiesHandler);
},
listener::onFailure
);
@ -94,10 +96,10 @@ public class ExtractedFieldsDetectorFactory {
getDocValueFieldsLimit(index, docValueFieldsLimitListener);
}
private void getCardinalitiesForFieldsWithLimit(String[] index, DataFrameAnalyticsConfig config,
ActionListener<Map<String, Long>> listener) {
Map<String, Long> fieldCardinalityLimits = config.getAnalysis().getFieldCardinalityLimits();
if (fieldCardinalityLimits.isEmpty()) {
private void getCardinalitiesForFieldsWithConstraints(String[] index, DataFrameAnalyticsConfig config,
ActionListener<Map<String, Long>> listener) {
List<FieldCardinalityConstraint> fieldCardinalityConstraints = config.getAnalysis().getFieldCardinalityConstraints();
if (fieldCardinalityConstraints.isEmpty()) {
listener.onResponse(Collections.emptyMap());
return;
}
@ -108,13 +110,11 @@ public class ExtractedFieldsDetectorFactory {
);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(config.getSource().getParsedQuery());
for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) {
String fieldName = entry.getKey();
Long limit = entry.getValue();
for (FieldCardinalityConstraint constraint : fieldCardinalityConstraints) {
searchSourceBuilder.aggregation(
AggregationBuilders.cardinality(fieldName)
.field(fieldName)
.precisionThreshold(limit + 1));
AggregationBuilders.cardinality(constraint.getField())
.field(constraint.getField())
.precisionThreshold(constraint.getUpperBound() + 1));
}
SearchRequest searchRequest = new SearchRequest(index).source(searchSourceBuilder);
ClientHelper.executeWithHeadersAsync(
@ -129,14 +129,14 @@ public class ExtractedFieldsDetectorFactory {
return;
}
Map<String, Long> fieldCardinalities = new HashMap<>(config.getAnalysis().getFieldCardinalityLimits().size());
for (String field : config.getAnalysis().getFieldCardinalityLimits().keySet()) {
Cardinality cardinality = aggs.get(field);
Map<String, Long> fieldCardinalities = new HashMap<>(config.getAnalysis().getFieldCardinalityConstraints().size());
for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) {
Cardinality cardinality = aggs.get(constraint.getField());
if (cardinality == null) {
listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
return;
}
fieldCardinalities.put(field, cardinality.getValue());
fieldCardinalities.put(constraint.getField(), cardinality.getValue());
}
listener.onResponse(fieldCardinalities);
}

View File

@ -109,9 +109,6 @@ public class AnalyticsProcessManager {
}
}
// Refresh the dest index to ensure data is searchable
refreshDest(config);
// Fetch existing model state (if any)
BytesReference state = getModelState(config);

View File

@ -560,8 +560,13 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.addAggregatableField("some_integer", "integer")
.build();
Map<String, Long> fieldCardinalities = new HashMap<>(2);
fieldCardinalities.put("some_boolean", 2L);
fieldCardinalities.put("some_integer", 2L);
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, config, 100, fieldCapabilities, config.getAnalysis().getFieldCardinalityLimits());
SOURCE_INDEX, config, 100, fieldCapabilities, fieldCardinalities);
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
List<ExtractedField> allFields = fieldExtraction.v1().getAllFields();

View File

@ -142,7 +142,7 @@
id: "start_given_empty_dest_index"
---
"Test start classification analysis when the dependent variable cardinality is too high":
"Test start classification analysis when the dependent variable cardinality is too low or too high":
- do:
indices.create:
index: index-with-dep-var-with-too-high-card
@ -152,11 +152,34 @@
numeric_field: { type: "long" }
keyword_field: { type: "keyword" }
- do:
ml.put_data_frame_analytics:
id: "classification-cardinality-limits"
body: >
{
"source": {
"index": "index-with-dep-var-with-too-high-card"
},
"dest": {
"index": "index-with-dep-var-with-too-high-card-dest"
},
"analysis": { "classification": { "dependent_variable": "keyword_field" } }
}
- do:
index:
index: index-with-dep-var-with-too-high-card
body: { numeric_field: 1.0, keyword_field: "class_a" }
- do:
indices.refresh:
index: index-with-dep-var-with-too-high-card
- do:
catch: /Field \[keyword_field\] must have at least \[2\] distinct values but there were \[1\]/
ml.start_data_frame_analytics:
id: "classification-cardinality-limits"
- do:
index:
index: index-with-dep-var-with-too-high-card
@ -171,21 +194,7 @@
indices.refresh:
index: index-with-dep-var-with-too-high-card
- do:
ml.put_data_frame_analytics:
id: "too-high-card"
body: >
{
"source": {
"index": "index-with-dep-var-with-too-high-card"
},
"dest": {
"index": "index-with-dep-var-with-too-high-card-dest"
},
"analysis": { "classification": { "dependent_variable": "keyword_field" } }
}
- do:
catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
ml.start_data_frame_analytics:
id: "too-high-card"
id: "classification-cardinality-limits"