[ML] do not summarize importance for custom features (#60198) (#60333)

If a feature is created via a custom pre-processor,
we should return the importance for that feature.

This means we will not return the importance for the
original document field for custom processed features.

closes https://github.com/elastic/elasticsearch/issues/59330
This commit is contained in:
Benjamin Trent 2020-07-28 15:58:20 -04:00 committed by GitHub
parent 9776424062
commit 54c8936508
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 155 additions and 129 deletions

View File

@ -95,6 +95,7 @@ public class InferenceDefinition {
return decoderMap;
}
this.decoderMap = preProcessors.stream()
.filter(p -> p.isCustom() == false)
.map(PreProcessor::reverseLookup)
.collect(HashMap::new, Map::putAll, Map::putAll);
return decoderMap;

View File

@ -119,7 +119,7 @@ public class InferenceDefinitionTests extends ESTestCase {
public void testComplexInferenceDefinitionInfer() throws IOException {
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(CLASSIFICATION_DEFINITION),
new BytesArray(getClassificationDefinition(false)),
XContentType.JSON);
InferenceDefinition inferenceDefinition = InferenceDefinition.fromXContent(parser);
@ -138,130 +138,155 @@ public class InferenceDefinitionTests extends ESTestCase {
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
}
/**
 * JSON inference definition used as a classification test fixture.
 *
 * It declares three pre-processors (a one-hot encoding of {@code col1}, a target-mean
 * encoding of {@code col2} and a frequency encoding of {@code col3}) followed by an
 * ensemble of two trees whose outputs are combined with {@code weighted_mode} into the
 * classification labels {@code first} and {@code second}.
 */
public static final String CLASSIFICATION_DEFINITION = "{" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"col1\",\n" +
" \"hot_map\": {\n" +
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" +
" \"target_map\": {\n" +
" \"S\": 5.0,\n" +
" \"M\": 10.0,\n" +
" \"L\": 20\n" +
" },\n" +
" \"default_value\": 5.0\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"col3\",\n" +
" \"feature_name\": \"col3_encoded\",\n" +
" \"frequency_map\": {\n" +
" \"none\": 0.75,\n" +
" \"true\": 0.10,\n" +
" \"false\": 0.15\n" +
" }\n" +
" }\n" +
" }\n" +
" ],\n" +
" \"trained_model\": {\n" +
" \"ensemble\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"aggregate_output\": {\n" +
" \"weighted_mode\": {\n" +
" \"num_classes\": \"2\",\n" +
" \"weights\": [\n" +
" 0.5,\n" +
" 0.5\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"target_type\": \"classification\",\n" +
" \"classification_labels\": [\"first\", \"second\"],\n" +
" \"trained_models\": [\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"number_samples\": 100,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 80,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 20,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" },\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"number_samples\": 180,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"number_samples\": 10,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"number_samples\": 170,\n" +
" \"leaf_value\": 0\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" }\n" +
"}";
/**
 * Runs inference with the classification definition whose {@code col1} one-hot encoder is
 * flagged as custom, then checks the predicted label and the first two feature importance
 * entries: {@code col2} first, followed by the encoded feature name {@code col1_male}.
 */
public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException {
    // Parse the definition with the "custom" flag switched on for the one-hot encoder.
    XContentParser definitionParser = XContentHelper.createParser(
        xContentRegistry(),
        DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
        new BytesArray(getClassificationDefinition(true)),
        XContentType.JSON);
    InferenceDefinition definition = InferenceDefinition.fromXContent(definitionParser);

    // Raw document fields, before any pre-processing is applied.
    Map<String, Object> document = new HashMap<>();
    document.put("col1", "female");
    document.put("col2", "M");
    document.put("col3", "none");
    document.put("col4", 10);

    ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2, null);
    ClassificationInferenceResults inferred =
        (ClassificationInferenceResults) definition.infer(document, classificationConfig);

    assertThat(inferred.valueAsString(), equalTo("second"));

    // First importance entry is reported under the original field name.
    assertThat(inferred.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
    assertThat(inferred.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));

    // Second importance entry is reported under the encoded feature name.
    assertThat(inferred.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
    assertThat(inferred.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
}
/**
 * Builds the JSON inference definition used as a classification test fixture: three
 * pre-processors (one-hot on {@code col1}, target-mean on {@code col2}, frequency on
 * {@code col3}) feeding an ensemble of two trees combined with {@code weighted_mode}.
 *
 * @param customPreprocessor value written to the {@code custom} flag of the {@code col1}
 *                           one-hot encoder
 * @return the definition as a JSON string
 */
public static String getClassificationDefinition(boolean customPreprocessor) {
    StringBuilder json = new StringBuilder("{");

    // Pre-processor 1: one-hot encoding of col1; the only part that depends on the argument.
    json.append(" \"preprocessors\": [\n")
        .append(" {\n")
        .append(" \"one_hot_encoding\": {\n")
        .append(" \"field\": \"col1\",\n")
        .append(" \"custom\": ").append(customPreprocessor).append(",\n")
        .append(" \"hot_map\": {\n")
        .append(" \"male\": \"col1_male\",\n")
        .append(" \"female\": \"col1_female\"\n")
        .append(" }\n")
        .append(" }\n")
        .append(" },\n");

    // Pre-processor 2: target-mean encoding of col2.
    json.append(" {\n")
        .append(" \"target_mean_encoding\": {\n")
        .append(" \"field\": \"col2\",\n")
        .append(" \"feature_name\": \"col2_encoded\",\n")
        .append(" \"target_map\": {\n")
        .append(" \"S\": 5.0,\n")
        .append(" \"M\": 10.0,\n")
        .append(" \"L\": 20\n")
        .append(" },\n")
        .append(" \"default_value\": 5.0\n")
        .append(" }\n")
        .append(" },\n");

    // Pre-processor 3: frequency encoding of col3, then close the pre-processor array.
    json.append(" {\n")
        .append(" \"frequency_encoding\": {\n")
        .append(" \"field\": \"col3\",\n")
        .append(" \"feature_name\": \"col3_encoded\",\n")
        .append(" \"frequency_map\": {\n")
        .append(" \"none\": 0.75,\n")
        .append(" \"true\": 0.10,\n")
        .append(" \"false\": 0.15\n")
        .append(" }\n")
        .append(" }\n")
        .append(" }\n")
        .append(" ],\n");

    // Ensemble header: feature names, output aggregation, target type and labels.
    json.append(" \"trained_model\": {\n")
        .append(" \"ensemble\": {\n")
        .append(" \"feature_names\": [\n")
        .append(" \"col1_male\",\n")
        .append(" \"col1_female\",\n")
        .append(" \"col2_encoded\",\n")
        .append(" \"col3_encoded\",\n")
        .append(" \"col4\"\n")
        .append(" ],\n")
        .append(" \"aggregate_output\": {\n")
        .append(" \"weighted_mode\": {\n")
        .append(" \"num_classes\": \"2\",\n")
        .append(" \"weights\": [\n")
        .append(" 0.5,\n")
        .append(" 0.5\n")
        .append(" ]\n")
        .append(" }\n")
        .append(" },\n")
        .append(" \"target_type\": \"classification\",\n")
        .append(" \"classification_labels\": [\"first\", \"second\"],\n")
        .append(" \"trained_models\": [\n");

    // First tree: splits on its feature 0 (col1_male).
    json.append(" {\n")
        .append(" \"tree\": {\n")
        .append(" \"feature_names\": [\n")
        .append(" \"col1_male\",\n")
        .append(" \"col1_female\",\n")
        .append(" \"col4\"\n")
        .append(" ],\n")
        .append(" \"tree_structure\": [\n")
        .append(" {\n")
        .append(" \"node_index\": 0,\n")
        .append(" \"split_feature\": 0,\n")
        .append(" \"number_samples\": 100,\n")
        .append(" \"split_gain\": 12.0,\n")
        .append(" \"threshold\": 10.0,\n")
        .append(" \"decision_type\": \"lte\",\n")
        .append(" \"default_left\": true,\n")
        .append(" \"left_child\": 1,\n")
        .append(" \"right_child\": 2\n")
        .append(" },\n")
        .append(" {\n")
        .append(" \"node_index\": 1,\n")
        .append(" \"number_samples\": 80,\n")
        .append(" \"leaf_value\": 1\n")
        .append(" },\n")
        .append(" {\n")
        .append(" \"node_index\": 2,\n")
        .append(" \"number_samples\": 20,\n")
        .append(" \"leaf_value\": 0\n")
        .append(" }\n")
        .append(" ],\n")
        .append(" \"target_type\": \"regression\"\n")
        .append(" }\n")
        .append(" },\n");

    // Second tree: splits on its feature 0 (col2_encoded).
    json.append(" {\n")
        .append(" \"tree\": {\n")
        .append(" \"feature_names\": [\n")
        .append(" \"col2_encoded\",\n")
        .append(" \"col3_encoded\",\n")
        .append(" \"col4\"\n")
        .append(" ],\n")
        .append(" \"tree_structure\": [\n")
        .append(" {\n")
        .append(" \"node_index\": 0,\n")
        .append(" \"split_feature\": 0,\n")
        .append(" \"split_gain\": 12.0,\n")
        .append(" \"number_samples\": 180,\n")
        .append(" \"threshold\": 10.0,\n")
        .append(" \"decision_type\": \"lte\",\n")
        .append(" \"default_left\": true,\n")
        .append(" \"left_child\": 1,\n")
        .append(" \"right_child\": 2\n")
        .append(" },\n")
        .append(" {\n")
        .append(" \"node_index\": 1,\n")
        .append(" \"number_samples\": 10,\n")
        .append(" \"leaf_value\": 1\n")
        .append(" },\n")
        .append(" {\n")
        .append(" \"node_index\": 2,\n")
        .append(" \"number_samples\": 170,\n")
        .append(" \"leaf_value\": 0\n")
        .append(" }\n")
        .append(" ],\n")
        .append(" \"target_type\": \"regression\"\n")
        .append(" }\n")
        .append(" }\n");

    // Close trained_models, ensemble, trained_model and the root object.
    json.append(" ]\n")
        .append(" }\n")
        .append(" }\n")
        .append("}");

    return json.toString();
}
}

View File

@ -34,7 +34,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests.CLASSIFICATION_DEFINITION;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests.getClassificationDefinition;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.CoreMatchers.containsString;
@ -532,7 +532,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"description\": \"test model for classification\",\n" +
" \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" +
" \"inference_config\": {\"classification\": {}},\n" +
" \"definition\": " + CLASSIFICATION_DEFINITION +
" \"definition\": " + getClassificationDefinition(false) +
"}";
private static String pipelineDefinition(String modelId, String inferenceConfig) {