diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index ef502655e49..d2d3a8ad4be 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -18,6 +18,7 @@ import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; @@ -293,20 +294,31 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { public void testWithCustomFeatureProcessors() throws Exception { initialize("classification_with_custom_feature_processors"); String predictedClassField = KEYWORD_FIELD + "_prediction"; - indexData(sourceIndex, 300, 50, KEYWORD_FIELD); + indexData(sourceIndex, 100, 0, KEYWORD_FIELD); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification( KEYWORD_FIELD, - BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), - null, - null, - null, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(0).build(), null, null, + 2, + 10.0, + 42L, Arrays.asList( - new OneHotEncoding(TEXT_FIELD, Collections.singletonMap(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom"), true) + new OneHotEncoding(ALIAS_TO_KEYWORD_FIELD, MapBuilder.newMapBuilder() + .put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom") + .put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom").map(), true), + new OneHotEncoding(ALIAS_TO_NESTED_FIELD, MapBuilder.newMapBuilder() + .put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_1") + .put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_1").map(), true), + new OneHotEncoding(NESTED_FIELD, MapBuilder.newMapBuilder() + .put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_2") + .put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_2").map(), true), + new OneHotEncoding(TEXT_FIELD, MapBuilder.newMapBuilder() + .put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_3") + .put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_3").map(), true) ))); putAnalytics(config); @@ -322,11 +334,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { Map destDoc = getDestDoc(config, hit); Map resultsObject = getFieldValue(destDoc, "ml"); assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); - assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); - @SuppressWarnings("unchecked") - List> importanceArray = (List>)resultsObject.get("feature_importance"); - assertThat(importanceArray, hasSize(greaterThan(0))); } assertProgressComplete(jobId); @@ -354,9 +362,13 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { TrainedModelConfig modelConfig = response.getResources().results().get(0); modelConfig.ensureParsedDefinition(xContentRegistry()); assertThat(modelConfig.getModelDefinition().getPreProcessors().size(), greaterThan(0)); - for (int i = 0; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) { + for (int i = 0; i < 4; i++) { PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i); - assertThat(preProcessor.isCustom(), equalTo(i == 0)); + assertThat(preProcessor.isCustom(), is(true)); + } + for (int i = 4; i < modelConfig.getModelDefinition().getPreProcessors().size(); i++) { + PreProcessor preProcessor = modelConfig.getModelDefinition().getPreProcessors().get(i); + assertThat(preProcessor.isCustom(), is(false)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index aed895cd571..015677dd3ca 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -77,17 +77,8 @@ public class DataFrameDataExtractor { DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) { this.client = Objects.requireNonNull(client); this.context = Objects.requireNonNull(context); - Set processedFieldInputs = context.extractedFields.getProcessedFieldInputs(); - this.organicFeatures = context.extractedFields.getAllFields() - .stream() - .map(ExtractedField::getName) - .filter(f -> processedFieldInputs.contains(f) == false) - .toArray(String[]::new); - this.processedFeatures = context.extractedFields.getProcessedFields() - .stream() - .map(ProcessedField::getOutputFieldNames) - .flatMap(List::stream) - .toArray(String[]::new); + this.organicFeatures = context.extractedFields.extractOrganicFeatureNames(); + this.processedFeatures = context.extractedFields.extractProcessedFeatureNames(); this.extractedFieldsByName = new LinkedHashMap<>(); context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f)); hasNext = true; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 353491aac8d..06898c925fd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -13,7 +13,6 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField; import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory; -import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.util.Arrays; @@ -22,6 +21,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; +import java.util.stream.Stream; public class DataFrameDataExtractorFactory { @@ -94,8 +94,13 @@ public class DataFrameDataExtractorFactory { private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config, ExtractedFields extractedFields) { - return new TrainTestSplitterFactory(client, config, - extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList())); + return new TrainTestSplitterFactory( + client, + config, + Stream.concat( + Arrays.stream(extractedFields.extractOrganicFeatureNames()), + Arrays.stream(extractedFields.extractProcessedFeatureNames()) + ).collect(Collectors.toList())); } /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java index 3853ea2629a..84171d83078 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java @@ -66,6 +66,23 @@ public class ExtractedFields { return cardinalitiesForFieldsWithConstraints; } + public String[] extractOrganicFeatureNames() { + Set processedFieldInputs = getProcessedFieldInputs(); + return allFields + .stream() + .map(ExtractedField::getName) + .filter(f -> processedFieldInputs.contains(f) == false) + .toArray(String[]::new); + } + + public String[] extractProcessedFeatureNames() { + return processedFields + .stream() + .map(ProcessedField::getOutputFieldNames) + .flatMap(List::stream) + .toArray(String[]::new); + } + private static List filterFields(ExtractedField.Method method, List fields) { return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java index d5c27f78103..f92632536cb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java @@ -9,15 +9,19 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilities; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.Map; import java.util.TreeSet; +import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -133,6 +137,33 @@ public class ExtractedFieldsTests extends ESTestCase { assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings")); } + public void testExtractFeatureOrganicAndProcessedNames() { + ExtractedField docValue1 = new DocValueField("doc1", Collections.singleton("keyword")); + ExtractedField docValue2 = new DocValueField("doc2", Collections.singleton("ip")); + ExtractedField scriptField1 = new ScriptField("scripted1"); + ExtractedField scriptField2 = new ScriptField("scripted2"); + ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text")); + ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text")); + + Map hotMap = new LinkedHashMap<>(); + hotMap.put("bar", "bar_column"); + hotMap.put("foo", "foo_column"); + + ExtractedFields extractedFields = new ExtractedFields( + Arrays.asList(docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), + Arrays.asList( + new ProcessedField(new NGram("doc1", "f", new int[] {1 , 2}, 0, 2, true)), + new ProcessedField(new OneHotEncoding("src1", hotMap, true))), + Collections.emptyMap()); + + + String[] organic = extractedFields.extractOrganicFeatureNames(); + assertThat(organic, arrayContaining("doc2", "scripted1", "scripted2", "src2")); + + String[] processed = extractedFields.extractProcessedFeatureNames(); + assertThat(processed, arrayContaining("f.10", "f.11", "f.20", "bar_column", "foo_column")); + } + private static FieldCapabilities createFieldCaps(boolean isAggregatable) { FieldCapabilities fieldCaps = mock(FieldCapabilities.class); when(fieldCaps.isAggregatable()).thenReturn(isAggregatable);