[ML] do not summarize importance for custom features (#60198) (#60333)

If a feature is created via a custom pre-processor,
we should return the importance for that feature.

This means we will not return the importance for the
original document field for custom-processed features.

closes https://github.com/elastic/elasticsearch/issues/59330
This commit is contained in:
Benjamin Trent 2020-07-28 15:58:20 -04:00 committed by GitHub
parent 9776424062
commit 54c8936508
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 155 additions and 129 deletions

View File

@ -95,6 +95,7 @@ public class InferenceDefinition {
return decoderMap; return decoderMap;
} }
this.decoderMap = preProcessors.stream() this.decoderMap = preProcessors.stream()
.filter(p -> p.isCustom() == false)
.map(PreProcessor::reverseLookup) .map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll); .collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap; return decoderMap;

View File

@ -119,7 +119,7 @@ public class InferenceDefinitionTests extends ESTestCase {
public void testComplexInferenceDefinitionInfer() throws IOException { public void testComplexInferenceDefinitionInfer() throws IOException {
XContentParser parser = XContentHelper.createParser(xContentRegistry(), XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION, DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(CLASSIFICATION_DEFINITION), new BytesArray(getClassificationDefinition(false)),
XContentType.JSON); XContentType.JSON);
InferenceDefinition inferenceDefinition = InferenceDefinition.fromXContent(parser); InferenceDefinition inferenceDefinition = InferenceDefinition.fromXContent(parser);
@ -138,130 +138,155 @@ public class InferenceDefinitionTests extends ESTestCase {
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001)); assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
} }
public static final String CLASSIFICATION_DEFINITION = "{" + public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException {
" \"preprocessors\": [\n" + XContentParser parser = XContentHelper.createParser(xContentRegistry(),
" {\n" + DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
" \"one_hot_encoding\": {\n" + new BytesArray(getClassificationDefinition(true)),
" \"field\": \"col1\",\n" + XContentType.JSON);
" \"hot_map\": {\n" + InferenceDefinition inferenceDefinition = InferenceDefinition.fromXContent(parser);
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" + ClassificationConfig config = new ClassificationConfig(2, null, null, 2, null);
" }\n" + Map<String, Object> featureMap = new HashMap<>();
" }\n" + featureMap.put("col1", "female");
" },\n" + featureMap.put("col2", "M");
" {\n" + featureMap.put("col3", "none");
" \"target_mean_encoding\": {\n" + featureMap.put("col4", 10);
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" + ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
" \"target_map\": {\n" + assertThat(results.valueAsString(), equalTo("second"));
" \"S\": 5.0,\n" + assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
" \"M\": 10.0,\n" + assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
" \"L\": 20\n" + assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
" },\n" + assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
" \"default_value\": 5.0\n" + }
" }\n" +
" },\n" + public static String getClassificationDefinition(boolean customPreprocessor) {
" {\n" + return "{" +
" \"frequency_encoding\": {\n" + " \"preprocessors\": [\n" +
" \"field\": \"col3\",\n" + " {\n" +
" \"feature_name\": \"col3_encoded\",\n" + " \"one_hot_encoding\": {\n" +
" \"frequency_map\": {\n" + " \"field\": \"col1\",\n" +
" \"none\": 0.75,\n" + " \"custom\": " + customPreprocessor + ",\n" +
" \"true\": 0.10,\n" + " \"hot_map\": {\n" +
" \"false\": 0.15\n" + " \"male\": \"col1_male\",\n" +
" }\n" + " \"female\": \"col1_female\"\n" +
" }\n" + " }\n" +
" }\n" + " }\n" +
" ],\n" + " },\n" +
" \"trained_model\": {\n" + " {\n" +
" \"ensemble\": {\n" + " \"target_mean_encoding\": {\n" +
" \"feature_names\": [\n" + " \"field\": \"col2\",\n" +
" \"col1_male\",\n" + " \"feature_name\": \"col2_encoded\",\n" +
" \"col1_female\",\n" + " \"target_map\": {\n" +
" \"col2_encoded\",\n" + " \"S\": 5.0,\n" +
" \"col3_encoded\",\n" + " \"M\": 10.0,\n" +
" \"col4\"\n" + " \"L\": 20\n" +
" ],\n" + " },\n" +
" \"aggregate_output\": {\n" + " \"default_value\": 5.0\n" +
" \"weighted_mode\": {\n" + " }\n" +
" \"num_classes\": \"2\",\n" + " },\n" +
" \"weights\": [\n" + " {\n" +
" 0.5,\n" + " \"frequency_encoding\": {\n" +
" 0.5\n" + " \"field\": \"col3\",\n" +
" ]\n" + " \"feature_name\": \"col3_encoded\",\n" +
" }\n" + " \"frequency_map\": {\n" +
" },\n" + " \"none\": 0.75,\n" +
" \"target_type\": \"classification\",\n" + " \"true\": 0.10,\n" +
" \"classification_labels\": [\"first\", \"second\"],\n" + " \"false\": 0.15\n" +
" \"trained_models\": [\n" + " }\n" +
" {\n" + " }\n" +
" \"tree\": {\n" + " }\n" +
" \"feature_names\": [\n" + " ],\n" +
" \"col1_male\",\n" + " \"trained_model\": {\n" +
" \"col1_female\",\n" + " \"ensemble\": {\n" +
" \"col4\"\n" + " \"feature_names\": [\n" +
" ],\n" + " \"col1_male\",\n" +
" \"tree_structure\": [\n" + " \"col1_female\",\n" +
" {\n" + " \"col2_encoded\",\n" +
" \"node_index\": 0,\n" + " \"col3_encoded\",\n" +
" \"split_feature\": 0,\n" + " \"col4\"\n" +
" \"number_samples\": 100,\n" + " ],\n" +
" \"split_gain\": 12.0,\n" + " \"aggregate_output\": {\n" +
" \"threshold\": 10.0,\n" + " \"weighted_mode\": {\n" +
" \"decision_type\": \"lte\",\n" + " \"num_classes\": \"2\",\n" +
" \"default_left\": true,\n" + " \"weights\": [\n" +
" \"left_child\": 1,\n" + " 0.5,\n" +
" \"right_child\": 2\n" + " 0.5\n" +
" },\n" + " ]\n" +
" {\n" + " }\n" +
" \"node_index\": 1,\n" + " },\n" +
" \"number_samples\": 80,\n" + " \"target_type\": \"classification\",\n" +
" \"leaf_value\": 1\n" + " \"classification_labels\": [\"first\", \"second\"],\n" +
" },\n" + " \"trained_models\": [\n" +
" {\n" + " {\n" +
" \"node_index\": 2,\n" + " \"tree\": {\n" +
" \"number_samples\": 20,\n" + " \"feature_names\": [\n" +
" \"leaf_value\": 0\n" + " \"col1_male\",\n" +
" }\n" + " \"col1_female\",\n" +
" ],\n" + " \"col4\"\n" +
" \"target_type\": \"regression\"\n" + " ],\n" +
" }\n" + " \"tree_structure\": [\n" +
" },\n" + " {\n" +
" {\n" + " \"node_index\": 0,\n" +
" \"tree\": {\n" + " \"split_feature\": 0,\n" +
" \"feature_names\": [\n" + " \"number_samples\": 100,\n" +
" \"col2_encoded\",\n" + " \"split_gain\": 12.0,\n" +
" \"col3_encoded\",\n" + " \"threshold\": 10.0,\n" +
" \"col4\"\n" + " \"decision_type\": \"lte\",\n" +
" ],\n" + " \"default_left\": true,\n" +
" \"tree_structure\": [\n" + " \"left_child\": 1,\n" +
" {\n" + " \"right_child\": 2\n" +
" \"node_index\": 0,\n" + " },\n" +
" \"split_feature\": 0,\n" + " {\n" +
" \"split_gain\": 12.0,\n" + " \"node_index\": 1,\n" +
" \"number_samples\": 180,\n" + " \"number_samples\": 80,\n" +
" \"threshold\": 10.0,\n" + " \"leaf_value\": 1\n" +
" \"decision_type\": \"lte\",\n" + " },\n" +
" \"default_left\": true,\n" + " {\n" +
" \"left_child\": 1,\n" + " \"node_index\": 2,\n" +
" \"right_child\": 2\n" + " \"number_samples\": 20,\n" +
" },\n" + " \"leaf_value\": 0\n" +
" {\n" + " }\n" +
" \"node_index\": 1,\n" + " ],\n" +
" \"number_samples\": 10,\n" + " \"target_type\": \"regression\"\n" +
" \"leaf_value\": 1\n" + " }\n" +
" },\n" + " },\n" +
" {\n" + " {\n" +
" \"node_index\": 2,\n" + " \"tree\": {\n" +
" \"number_samples\": 170,\n" + " \"feature_names\": [\n" +
" \"leaf_value\": 0\n" + " \"col2_encoded\",\n" +
" }\n" + " \"col3_encoded\",\n" +
" ],\n" + " \"col4\"\n" +
" \"target_type\": \"regression\"\n" + " ],\n" +
" }\n" + " \"tree_structure\": [\n" +
" }\n" + " {\n" +
" ]\n" + " \"node_index\": 0,\n" +
" }\n" + " \"split_feature\": 0,\n" +
" }\n" + " \"split_gain\": 12.0,\n" +
"}"; " \"number_samples\": 180,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 10,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 170,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" }\n" +
"}";
}
} }

View File

@ -34,7 +34,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests.CLASSIFICATION_DEFINITION; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests.getClassificationDefinition;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
@ -532,7 +532,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"description\": \"test model for classification\",\n" + " \"description\": \"test model for classification\",\n" +
" \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" + " \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" +
" \"inference_config\": {\"classification\": {}},\n" + " \"inference_config\": {\"classification\": {}},\n" +
" \"definition\": " + CLASSIFICATION_DEFINITION + " \"definition\": " + getClassificationDefinition(false) +
"}"; "}";
private static String pipelineDefinition(String modelId, String inferenceConfig) { private static String pipelineDefinition(String modelId, String inferenceConfig) {