[ML] Fix bug with data frame analytics classification test data sampling when using custom feature processors (#64727) (#64864)

When using custom feature processors, the field names extracted from the documents are not the
same as the feature names used for training.

Consequently, the stratified sampler can end up with an incorrect view of the feature rows.
This can lead to the wrong column being read for the class label, which in turn throws errors
during training row extraction.

This commit changes the training row feature names used by the stratified sampler so that they match
the names (and their order) that are sent to the analytics process.
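
To make the mismatch concrete, here is a minimal sketch (the field and column names below are
hypothetical, not taken from this commit): a custom one-hot processor reads one document field
but emits differently named feature columns, so the sampler has to be built from the emitted
names, in the order the analytics process receives them.

import java.util.LinkedHashMap;
import java.util.Map;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;

// Hypothetical example: the processor reads the document field "animal"
// but produces two feature columns with entirely different names.
Map<String, String> hotMap = new LinkedHashMap<>();
hotMap.put("cat", "animal_cat_column");
hotMap.put("dog", "animal_dog_column");
OneHotEncoding oneHot = new OneHotEncoding("animal", hotMap, true); // true == custom

// Fields extracted from documents:      [animal, some_numeric_field]
// Feature columns sent to the process:  [some_numeric_field, animal_cat_column, animal_dog_column]
// Before this fix the sampler was constructed from the extracted field names,
// so the class-label column index it computed could point at the wrong column.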
Benjamin Trent 2020-11-10 08:47:07 -05:00 committed by GitHub
parent dafafd7ec6
commit f0ff673f82
5 changed files with 82 additions and 26 deletions

ClassificationIT.java

@@ -18,6 +18,7 @@ import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
@@ -293,20 +294,31 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
public void testWithCustomFeatureProcessors() throws Exception {
initialize("classification_with_custom_feature_processors");
String predictedClassField = KEYWORD_FIELD + "_prediction";
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
indexData(sourceIndex, 100, 0, KEYWORD_FIELD);
DataFrameAnalyticsConfig config =
buildAnalytics(jobId, sourceIndex, destIndex, null,
new Classification(
KEYWORD_FIELD,
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
null,
null,
null,
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(0).build(),
null,
null,
2,
10.0,
42L,
Arrays.asList(
new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true)
new OneHotEncoding(ALIAS_TO_KEYWORD_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom").map(), true),
new OneHotEncoding(ALIAS_TO_NESTED_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_1")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_1").map(), true),
new OneHotEncoding(NESTED_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_2")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_2").map(), true),
new OneHotEncoding(TEXT_FIELD, MapBuilder.<String, String>newMapBuilder()
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_3")
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_3").map(), true)
)));
putAnalytics(config);
@@ -322,11 +334,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
assertThat(importanceArray, hasSize(greaterThan(0)));
}
assertProgressComplete(jobId);
@@ -354,9 +362,13 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
TrainedModelConfig modelConfig = response.getResources().results().get(0);
modelConfig.ensureParsedDefinition(xContentRegistry());
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
for (int i = 0; i < 4; i++) {
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
assertThat(preProcessor.isCustom(), equalTo(i == 0));
assertThat(preProcessor.isCustom(), is(true));
}
for (int i = 4; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
assertThat(preProcessor.isCustom(), is(false));
}
}

DataFrameDataExtractor.java

@@ -77,17 +77,8 @@ public class DataFrameDataExtractor {
DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
this.client = Objects.requireNonNull(client);
this.context = Objects.requireNonNull(context);
Set<String> processedFieldInputs = context.extractedFields.getProcessedFieldInputs();
this.organicFeatures = context.extractedFields.getAllFields()
.stream()
.map(ExtractedField::getName)
.filter(f -> processedFieldInputs.contains(f) == false)
.toArray(String[]::new);
this.processedFeatures = context.extractedFields.getProcessedFields()
.stream()
.map(ProcessedField::getOutputFieldNames)
.flatMap(List::stream)
.toArray(String[]::new);
this.organicFeatures = context.extractedFields.extractOrganicFeatureNames();
this.processedFeatures = context.extractedFields.extractProcessedFeatureNames();
this.extractedFieldsByName = new LinkedHashMap<>();
context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f));
hasNext = true;

DataFrameDataExtractorFactory.java

@@ -13,7 +13,6 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import java.util.Arrays;
@@ -22,6 +21,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class DataFrameDataExtractorFactory {
@@ -94,8 +94,13 @@ public class DataFrameDataExtractorFactory {
private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields) {
return new TrainTestSplitterFactory(client, config,
extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()));
return new TrainTestSplitterFactory(
client,
config,
Stream.concat(
Arrays.stream(extractedFields.extractOrganicFeatureNames()),
Arrays.stream(extractedFields.extractProcessedFeatureNames())
).collect(Collectors.toList()));
}
/**

ExtractedFields.java

@@ -66,6 +66,23 @@ public class ExtractedFields {
return cardinalitiesForFieldsWithConstraints;
}
public String[] extractOrganicFeatureNames() {
Set<String> processedFieldInputs = getProcessedFieldInputs();
return allFields
.stream()
.map(ExtractedField::getName)
.filter(f -> processedFieldInputs.contains(f) == false)
.toArray(String[]::new);
}
public String[] extractProcessedFeatureNames() {
return processedFields
.stream()
.map(ProcessedField::getOutputFieldNames)
.flatMap(List::stream)
.toArray(String[]::new);
}
private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
}

ExtractedFieldsTests.java

@@ -9,15 +9,19 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilities;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.TreeSet;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
@@ -133,6 +137,33 @@ public class ExtractedFieldsTests extends ESTestCase {
assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
}
public void testExtractFeatureOrganicAndProcessedNames() {
ExtractedField docValue1 = new DocValueField("doc1", Collections.singleton("keyword"));
ExtractedField docValue2 = new DocValueField("doc2", Collections.singleton("ip"));
ExtractedField scriptField1 = new ScriptField("scripted1");
ExtractedField scriptField2 = new ScriptField("scripted2");
ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
Map<String, String> hotMap = new LinkedHashMap<>();
hotMap.put("bar", "bar_column");
hotMap.put("foo", "foo_column");
ExtractedFields extractedFields = new ExtractedFields(
Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2),
Arrays.asList(
new ProcessedField(new NGram("doc1", "f", new int[] {1, 2}, 0, 2, true)),
new ProcessedField(new OneHotEncoding("src1", hotMap, true))),
Collections.emptyMap());
String[] organic = extractedFields.extractOrganicFeatureNames();
assertThat(organic, arrayContaining("doc2", "scripted1", "scripted2", "src2"));
String[] processed = extractedFields.extractProcessedFeatureNames();
assertThat(processed, arrayContaining("f.10", "f.11", "f.20", "bar_column", "foo_column"));
}
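
Aside: the expected processed names asserted above follow from the NGram configuration
(feature prefix "f", n-gram lengths {1, 2}, window start 0, window length 2) under the
output-naming convention <prefix>.<nGramLength><offset>. That convention is inferred here
from the assertion itself, not from a documented contract, so treat this as a sketch:

import java.util.ArrayList;
import java.util.List;

// Enumerate the n-gram output names for a window of length 2.
String prefix = "f";
List<String> expected = new ArrayList<>();
for (int n : new int[] {1, 2}) {
    for (int offset = 0; offset + n <= 2; offset++) {
        expected.add(prefix + "." + n + offset);  // e.g. "f.10"
    }
}
// expected == [f.10, f.11, f.20], matching the arrayContaining assertion above.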
private static FieldCapabilities createFieldCaps(boolean isAggregatable) {
FieldCapabilities fieldCaps = mock(FieldCapabilities.class);
when(fieldCaps.isAggregatable()).thenReturn(isAggregatable);