mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-17 10:25:15 +00:00
[ML] Fix bug with data frame analytics classification test data sampling when using custom feature processors (#64727) (#64864)
When using custom processors, the field names extracted from the documents are not the same as the feature names used for training. Consequently, it is possible for the stratified sampler to have an incorrect view of the feature rows. This can lead to the wrong column being read for the class label, and thus throw errors on training row extraction. This commit changes the training row feature names used by the stratified sampler so that it matches the names (and their order) that are sent to the analytics process.
This commit is contained in:
parent
dafafd7ec6
commit
f0ff673f82
@ -18,6 +18,7 @@ import org.elasticsearch.action.index.IndexAction;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.collect.MapBuilder;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
@ -293,20 +294,31 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
public void testWithCustomFeatureProcessors() throws Exception {
|
||||
initialize("classification_with_custom_feature_processors");
|
||||
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 300, 50, KEYWORD_FIELD);
|
||||
indexData(sourceIndex, 100, 0, KEYWORD_FIELD);
|
||||
|
||||
DataFrameAnalyticsConfig config =
|
||||
buildAnalytics(jobId, sourceIndex, destIndex, null,
|
||||
new Classification(
|
||||
KEYWORD_FIELD,
|
||||
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
BoostedTreeParams.builder().setNumTopFeatureImportanceValues(0).build(),
|
||||
null,
|
||||
null,
|
||||
2,
|
||||
10.0,
|
||||
42L,
|
||||
Arrays.asList(
|
||||
new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true)
|
||||
new OneHotEncoding(ALIAS_TO_KEYWORD_FIELD, MapBuilder.<String, String>newMapBuilder()
|
||||
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom")
|
||||
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom").map(), true),
|
||||
new OneHotEncoding(ALIAS_TO_NESTED_FIELD, MapBuilder.<String, String>newMapBuilder()
|
||||
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_1")
|
||||
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_1").map(), true),
|
||||
new OneHotEncoding(NESTED_FIELD, MapBuilder.<String, String>newMapBuilder()
|
||||
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_2")
|
||||
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_2").map(), true),
|
||||
new OneHotEncoding(TEXT_FIELD, MapBuilder.<String, String>newMapBuilder()
|
||||
.put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_3")
|
||||
.put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_3").map(), true)
|
||||
)));
|
||||
putAnalytics(config);
|
||||
|
||||
@ -322,11 +334,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
||||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
|
||||
assertThat(importanceArray, hasSize(greaterThan(0)));
|
||||
}
|
||||
|
||||
assertProgressComplete(jobId);
|
||||
@ -354,9 +362,13 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
TrainedModelConfig modelConfig = response.getResources().results().get(0);
|
||||
modelConfig.ensureParsedDefinition(xContentRegistry());
|
||||
assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0));
|
||||
for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
|
||||
assertThat(preProcessor.isCustom(), equalTo(i == 0));
|
||||
assertThat(preProcessor.isCustom(), is(true));
|
||||
}
|
||||
for (int i = 4; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) {
|
||||
PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i);
|
||||
assertThat(preProcessor.isCustom(), is(false));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,17 +77,8 @@ public class DataFrameDataExtractor {
|
||||
DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.context = Objects.requireNonNull(context);
|
||||
Set<String> processedFieldInputs = context.extractedFields.getProcessedFieldInputs();
|
||||
this.organicFeatures = context.extractedFields.getAllFields()
|
||||
.stream()
|
||||
.map(ExtractedField::getName)
|
||||
.filter(f -> processedFieldInputs.contains(f) == false)
|
||||
.toArray(String[]::new);
|
||||
this.processedFeatures = context.extractedFields.getProcessedFields()
|
||||
.stream()
|
||||
.map(ProcessedField::getOutputFieldNames)
|
||||
.flatMap(List::stream)
|
||||
.toArray(String[]::new);
|
||||
this.organicFeatures = context.extractedFields.extractOrganicFeatureNames();
|
||||
this.processedFeatures = context.extractedFields.extractProcessedFeatureNames();
|
||||
this.extractedFieldsByName = new LinkedHashMap<>();
|
||||
context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f));
|
||||
hasNext = true;
|
||||
|
@ -13,7 +13,6 @@ import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
|
||||
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
|
||||
import java.util.Arrays;
|
||||
@ -22,6 +21,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class DataFrameDataExtractorFactory {
|
||||
|
||||
@ -94,8 +94,13 @@ public class DataFrameDataExtractorFactory {
|
||||
|
||||
private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config,
|
||||
ExtractedFields extractedFields) {
|
||||
return new TrainTestSplitterFactory(client, config,
|
||||
extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()));
|
||||
return new TrainTestSplitterFactory(
|
||||
client,
|
||||
config,
|
||||
Stream.concat(
|
||||
Arrays.stream(extractedFields.extractOrganicFeatureNames()),
|
||||
Arrays.stream(extractedFields.extractProcessedFeatureNames())
|
||||
).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -66,6 +66,23 @@ public class ExtractedFields {
|
||||
return cardinalitiesForFieldsWithConstraints;
|
||||
}
|
||||
|
||||
public String[] extractOrganicFeatureNames() {
|
||||
Set<String> processedFieldInputs = getProcessedFieldInputs();
|
||||
return allFields
|
||||
.stream()
|
||||
.map(ExtractedField::getName)
|
||||
.filter(f -> processedFieldInputs.contains(f) == false)
|
||||
.toArray(String[]::new);
|
||||
}
|
||||
|
||||
public String[] extractProcessedFeatureNames() {
|
||||
return processedFields
|
||||
.stream()
|
||||
.map(ProcessedField::getOutputFieldNames)
|
||||
.flatMap(List::stream)
|
||||
.toArray(String[]::new);
|
||||
}
|
||||
|
||||
private static List<ExtractedField> filterFields(ExtractedField.Method method, List<ExtractedField> fields) {
|
||||
return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList());
|
||||
}
|
||||
|
@ -9,15 +9,19 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilities;
|
||||
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
|
||||
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
|
||||
import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import static org.hamcrest.Matchers.arrayContaining;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
@ -133,6 +137,33 @@ public class ExtractedFieldsTests extends ESTestCase {
|
||||
assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings"));
|
||||
}
|
||||
|
||||
public void testExtractFeatureOrganicAndProcessedNames() {
|
||||
ExtractedField docValue1 = new DocValueField("doc1", Collections.singleton("keyword"));
|
||||
ExtractedField docValue2 = new DocValueField("doc2", Collections.singleton("ip"));
|
||||
ExtractedField scriptField1 = new ScriptField("scripted1");
|
||||
ExtractedField scriptField2 = new ScriptField("scripted2");
|
||||
ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text"));
|
||||
ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text"));
|
||||
|
||||
Map<String, String> hotMap = new LinkedHashMap<>();
|
||||
hotMap.put("bar", "bar_column");
|
||||
hotMap.put("foo", "foo_column");
|
||||
|
||||
ExtractedFields extractedFields = new ExtractedFields(
|
||||
Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2),
|
||||
Arrays.asList(
|
||||
new ProcessedField(new NGram("doc1", "f", new int[] {1 , 2}, 0, 2, true)),
|
||||
new ProcessedField(new OneHotEncoding("src1", hotMap, true))),
|
||||
Collections.emptyMap());
|
||||
|
||||
|
||||
String[] organic = extractedFields.extractOrganicFeatureNames();
|
||||
assertThat(organic, arrayContaining("doc2", "scripted1", "scripted2", "src2"));
|
||||
|
||||
String[] processed = extractedFields.extractProcessedFeatureNames();
|
||||
assertThat(processed, arrayContaining("f.10", "f.11", "f.20", "bar_column", "foo_column"));
|
||||
}
|
||||
|
||||
private static FieldCapabilities createFieldCaps(boolean isAggregatable) {
|
||||
FieldCapabilities fieldCaps = mock(FieldCapabilities.class);
|
||||
when(fieldCaps.isAggregatable()).thenReturn(isAggregatable);
|
||||
|
Loading…
x
Reference in New Issue
Block a user