[7.x][ML] Deduplicate multi-fields for data frame analytics (#48799) (#48806)

In the case multi-fields exist in the source index, we pick
all variants of them in our extracted fields detection for
data frame analytics. This means we may have multiple instances
of the same feature. The worst consequence of this is when the
dependent variable (for regression or classification) is also
duplicated, which means we train a model on the dependent variable
itself.

Now that #48770 is merged, this commit is adding logic to
only select one variant of multi-fields.

Closes #48756

Backport of #48799
This commit is contained in:
Dimitris Athanasiou 2019-11-01 16:53:05 +02:00 committed by GitHub
parent fd4ae697b8
commit f2d4c94a9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 209 additions and 3 deletions

View File

@ -31,6 +31,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -238,7 +239,9 @@ public class ExtractedFieldsDetector {
// We sort the fields to ensure the checksum for each document is deterministic
Collections.sort(sortedFields);
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
extractedFields = deduplicateMultiFields(extractedFields, preferSource);
if (preferSource) {
extractedFields = fetchFromSourceIfSupported(extractedFields);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " +
@ -250,9 +253,59 @@ public class ExtractedFieldsDetector {
return extractedFields;
}
/**
 * Collapses each parent field and its multi-field variants down to a single
 * extracted field, so the same underlying value is never used twice as a feature.
 */
private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource) {
    // Fields the analysis itself depends on (e.g. the dependent variable) must never be dropped.
    Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream()
        .map(RequiredField::getName)
        .collect(Collectors.toSet());

    // Keyed by the field's own name, or by its parent's name when it is a multi-field,
    // so that a parent and all of its variants land on the same entry.
    // LinkedHashMap preserves the incoming (sorted) field order.
    Map<String, ExtractedField> fieldByRoot = new LinkedHashMap<>();
    for (ExtractedField field : extractedFields.getAllFields()) {
        String root = field.isMultiField() ? field.getParentField() : field.getName();
        ExtractedField previous = fieldByRoot.putIfAbsent(root, field);
        if (previous != null) {
            // Two candidates share the same root; decide which one survives.
            ExtractedField parent = field.isMultiField() ? previous : field;
            ExtractedField multiField = field.isMultiField() ? field : previous;
            fieldByRoot.put(root, chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField));
        }
    }
    return new ExtractedFields(new ArrayList<>(fieldByRoot.values()));
}
/**
 * Picks which of a parent field and one of its multi-field variants should be kept.
 *
 * Preference order: a required field always wins; otherwise keep a previously
 * selected variant, then honour source-fetching when requested, then prefer
 * doc_values-backed fields, and finally fall back to the parent.
 */
private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields,
                                                ExtractedField parent, ExtractedField multiField) {
    // A field the analysis requires (e.g. the dependent variable) always wins.
    if (requiredFields.contains(parent.getName())) {
        return parent;
    }
    if (requiredFields.contains(multiField.getName())) {
        return multiField;
    }

    // When both candidates are multi-fields there are several variants; "parent" is then
    // the variant chosen in an earlier round, so keep it for a stable selection.
    if (parent.isMultiField() && multiField.isMultiField()) {
        return parent;
    }

    // Only the parent can ever be fetched from _source; take it immediately when
    // source fetching is preferred and supported.
    if (preferSource && parent.supportsFromSource()) {
        return parent;
    }

    // Prefer whichever field is backed by doc_values, as that supports aggregations.
    // The parent is checked first since its name is shorter.
    boolean parentHasDocValues = parent.getMethod() == ExtractedField.Method.DOC_VALUE;
    boolean multiFieldHasDocValues = multiField.getMethod() == ExtractedField.Method.DOC_VALUE;
    if (parentHasDocValues) {
        return parent;
    }
    if (multiFieldHasDocValues) {
        return multiField;
    }

    // Neither is aggregatable; fall back to the parent for its shorter name.
    return parent;
}
private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getDocValueFields()) {
for (ExtractedField field : extractedFields.getAllFields()) {
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
}
return new ExtractedFields(adjusted);

View File

@ -534,6 +534,151 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
}
public void testDetect_GivenMultiFields() {
    // Each parent/multi-field pair should contribute exactly one variant to the features.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addAggregatableField("a_float", "float")
        .addNonAggregatableField("text_without_keyword", "text")
        .addNonAggregatableField("text_1", "text")
        .addAggregatableField("text_1.keyword", "keyword")
        .addNonAggregatableField("text_2", "text")
        .addAggregatableField("text_2.keyword", "keyword")
        .addAggregatableField("keyword_1", "keyword")
        .addNonAggregatableField("keyword_1.text", "text")
        .build();
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("a_float"), RESULTS_FIELD, true, 100, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(5));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    // Aggregatable variants win; text without a keyword sub-field is kept as-is.
    assertThat(names, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword"));
}
public void testDetect_GivenMultiFieldAndParentIsRequired() {
    // The dependent variable is the parent, so the parent must be kept over its multi-field.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addAggregatableField("field_1", "keyword")
        .addAggregatableField("field_1.keyword", "keyword")
        .addAggregatableField("field_2", "float")
        .build();
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildClassificationConfig("field_1"), RESULTS_FIELD, true, 100, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(2));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}
public void testDetect_GivenMultiFieldAndMultiFieldIsRequired() {
    // The dependent variable is the multi-field, so the multi-field must win over its parent.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addAggregatableField("field_1", "keyword")
        .addAggregatableField("field_1.keyword", "keyword")
        .addAggregatableField("field_2", "float")
        .build();
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), RESULTS_FIELD, true, 100, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(2));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1.keyword", "field_2"));
}
public void testDetect_GivenSeveralMultiFields_ShouldPickFirstSorted() {
    // With multiple multi-field variants, selection is stable: the first in sorted order wins.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addNonAggregatableField("field_1", "text")
        .addAggregatableField("field_1.keyword_3", "keyword")
        .addAggregatableField("field_1.keyword_2", "keyword")
        .addAggregatableField("field_1.keyword_1", "keyword")
        .addAggregatableField("field_2", "float")
        .build();
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(2));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1.keyword_1", "field_2"));
}
public void testDetect_GivenMultiFields_OverDocValueLimit() {
    // With a doc_value limit of 0 fields are fetched from _source, so the parent
    // (the only variant _source supports) must be chosen over the keyword multi-field.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addNonAggregatableField("field_1", "text")
        .addAggregatableField("field_1.keyword_1", "keyword")
        .addAggregatableField("field_2", "float")
        .build();
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 0, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(2));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}
public void testDetect_GivenParentAndMultiFieldBothAggregatable() {
    // When both parent and multi-field are aggregatable, the parent wins (shorter name),
    // unless the multi-field is the required dependent variable.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addAggregatableField("field_1", "keyword")
        .addAggregatableField("field_1.keyword", "keyword")
        .addAggregatableField("field_2.keyword", "float")
        .addAggregatableField("field_2.double", "double")
        .build();
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2.double"), RESULTS_FIELD, true, 100, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(2));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2.double"));
}
public void testDetect_GivenParentAndMultiFieldNoneAggregatable() {
    // When neither variant is aggregatable, the parent is chosen for its shorter name.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addNonAggregatableField("field_1", "text")
        .addNonAggregatableField("field_1.text", "text")
        .addAggregatableField("field_2", "float")
        .build();
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(2));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}
public void testDetect_GivenMultiFields_AndExplicitlyIncludedFields() {
    // Explicitly analyzed fields restrict detection before deduplication;
    // only the listed parent fields should appear in the result.
    FieldCapabilitiesResponse fieldCaps = new MockFieldCapsResponseBuilder()
        .addNonAggregatableField("field_1", "text")
        .addAggregatableField("field_1.keyword", "keyword")
        .addAggregatableField("field_2", "float")
        .build();
    FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[] { "field_1", "field_2" }, new String[0]);
    ExtractedFieldsDetector detector = new ExtractedFieldsDetector(
        SOURCE_INDEX, buildRegressionConfig("field_2", analyzedFields), RESULTS_FIELD, false, 100, fieldCaps);

    ExtractedFields extracted = detector.detect();

    assertThat(extracted.getAllFields().size(), equalTo(2));
    List<String> names = extracted.getAllFields().stream()
        .map(ExtractedField::getName)
        .collect(Collectors.toList());
    assertThat(names, contains("field_1", "field_2"));
}
// Convenience overload: builds an outlier detection config with no explicitly analyzed fields.
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
return buildOutlierDetectionConfig(null);
}
@ -576,9 +721,17 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
// Registers a field whose capabilities report it as aggregatable (doc_values-backed).
private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) {
return addField(field, true, types);
}
// Registers a field whose capabilities report it as NOT aggregatable (e.g. a text field).
private MockFieldCapsResponseBuilder addNonAggregatableField(String field, String... types) {
return addField(field, false, types);
}
private MockFieldCapsResponseBuilder addField(String field, boolean isAggregatable, String... types) {
Map<String, FieldCapabilities> caps = new HashMap<>();
for (String type : types) {
caps.put(type, new FieldCapabilities(field, type, true, true));
caps.put(type, new FieldCapabilities(field, type, true, isAggregatable));
}
fieldCaps.put(field, caps);
return this;