If a feature is created via a custom pre-processor, we should return the importance for that feature. This means we will not return the importance for the original document field for custom processed features. closes https://github.com/elastic/elasticsearch/issues/59330
This commit is contained in:
parent
9776424062
commit
54c8936508
|
@ -95,6 +95,7 @@ public class InferenceDefinition {
|
|||
return decoderMap;
|
||||
}
|
||||
this.decoderMap = preProcessors.stream()
|
||||
.filter(p -> p.isCustom() == false)
|
||||
.map(PreProcessor::reverseLookup)
|
||||
.collect(HashMap::new, Map::putAll, Map::putAll);
|
||||
return decoderMap;
|
||||
|
|
|
@ -119,7 +119,7 @@ public class InferenceDefinitionTests extends ESTestCase {
|
|||
public void testComplexInferenceDefinitionInfer() throws IOException {
|
||||
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
new BytesArray(CLASSIFICATION_DEFINITION),
|
||||
new BytesArray(getClassificationDefinition(false)),
|
||||
XContentType.JSON);
|
||||
InferenceDefinition inferenceDefinition = InferenceDefinition.fromXContent(parser);
|
||||
|
||||
|
@ -138,130 +138,155 @@ public class InferenceDefinitionTests extends ESTestCase {
|
|||
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
|
||||
}
|
||||
|
||||
public static final String CLASSIFICATION_DEFINITION = "{" +
|
||||
" \"preprocessors\": [\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"col1\",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"male\": \"col1_male\",\n" +
|
||||
" \"female\": \"col1_female\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"target_mean_encoding\": {\n" +
|
||||
" \"field\": \"col2\",\n" +
|
||||
" \"feature_name\": \"col2_encoded\",\n" +
|
||||
" \"target_map\": {\n" +
|
||||
" \"S\": 5.0,\n" +
|
||||
" \"M\": 10.0,\n" +
|
||||
" \"L\": 20\n" +
|
||||
" },\n" +
|
||||
" \"default_value\": 5.0\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"frequency_encoding\": {\n" +
|
||||
" \"field\": \"col3\",\n" +
|
||||
" \"feature_name\": \"col3_encoded\",\n" +
|
||||
" \"frequency_map\": {\n" +
|
||||
" \"none\": 0.75,\n" +
|
||||
" \"true\": 0.10,\n" +
|
||||
" \"false\": 0.15\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"trained_model\": {\n" +
|
||||
" \"ensemble\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"aggregate_output\": {\n" +
|
||||
" \"weighted_mode\": {\n" +
|
||||
" \"num_classes\": \"2\",\n" +
|
||||
" \"weights\": [\n" +
|
||||
" 0.5,\n" +
|
||||
" 0.5\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"target_type\": \"classification\",\n" +
|
||||
" \"classification_labels\": [\"first\", \"second\"],\n" +
|
||||
" \"trained_models\": [\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"number_samples\": 100,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"number_samples\": 80,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"number_samples\": 20,\n" +
|
||||
" \"leaf_value\": 0\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"number_samples\": 180,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"number_samples\": 10,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"number_samples\": 170,\n" +
|
||||
" \"leaf_value\": 0\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}";
|
||||
public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException {
|
||||
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
||||
new BytesArray(getClassificationDefinition(true)),
|
||||
XContentType.JSON);
|
||||
InferenceDefinition inferenceDefinition = InferenceDefinition.fromXContent(parser);
|
||||
|
||||
ClassificationConfig config = new ClassificationConfig(2, null, null, 2, null);
|
||||
Map<String, Object> featureMap = new HashMap<>();
|
||||
featureMap.put("col1", "female");
|
||||
featureMap.put("col2", "M");
|
||||
featureMap.put("col3", "none");
|
||||
featureMap.put("col4", 10);
|
||||
|
||||
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
|
||||
assertThat(results.valueAsString(), equalTo("second"));
|
||||
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
|
||||
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
|
||||
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
|
||||
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
|
||||
}
|
||||
|
||||
public static String getClassificationDefinition(boolean customPreprocessor) {
|
||||
return "{" +
|
||||
" \"preprocessors\": [\n" +
|
||||
" {\n" +
|
||||
" \"one_hot_encoding\": {\n" +
|
||||
" \"field\": \"col1\",\n" +
|
||||
" \"custom\": " + customPreprocessor + ",\n" +
|
||||
" \"hot_map\": {\n" +
|
||||
" \"male\": \"col1_male\",\n" +
|
||||
" \"female\": \"col1_female\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"target_mean_encoding\": {\n" +
|
||||
" \"field\": \"col2\",\n" +
|
||||
" \"feature_name\": \"col2_encoded\",\n" +
|
||||
" \"target_map\": {\n" +
|
||||
" \"S\": 5.0,\n" +
|
||||
" \"M\": 10.0,\n" +
|
||||
" \"L\": 20\n" +
|
||||
" },\n" +
|
||||
" \"default_value\": 5.0\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"frequency_encoding\": {\n" +
|
||||
" \"field\": \"col3\",\n" +
|
||||
" \"feature_name\": \"col3_encoded\",\n" +
|
||||
" \"frequency_map\": {\n" +
|
||||
" \"none\": 0.75,\n" +
|
||||
" \"true\": 0.10,\n" +
|
||||
" \"false\": 0.15\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"trained_model\": {\n" +
|
||||
" \"ensemble\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"aggregate_output\": {\n" +
|
||||
" \"weighted_mode\": {\n" +
|
||||
" \"num_classes\": \"2\",\n" +
|
||||
" \"weights\": [\n" +
|
||||
" 0.5,\n" +
|
||||
" 0.5\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"target_type\": \"classification\",\n" +
|
||||
" \"classification_labels\": [\"first\", \"second\"],\n" +
|
||||
" \"trained_models\": [\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col1_male\",\n" +
|
||||
" \"col1_female\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"number_samples\": 100,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"number_samples\": 80,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"number_samples\": 20,\n" +
|
||||
" \"leaf_value\": 0\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"tree\": {\n" +
|
||||
" \"feature_names\": [\n" +
|
||||
" \"col2_encoded\",\n" +
|
||||
" \"col3_encoded\",\n" +
|
||||
" \"col4\"\n" +
|
||||
" ],\n" +
|
||||
" \"tree_structure\": [\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 0,\n" +
|
||||
" \"split_feature\": 0,\n" +
|
||||
" \"split_gain\": 12.0,\n" +
|
||||
" \"number_samples\": 180,\n" +
|
||||
" \"threshold\": 10.0,\n" +
|
||||
" \"decision_type\": \"lte\",\n" +
|
||||
" \"default_left\": true,\n" +
|
||||
" \"left_child\": 1,\n" +
|
||||
" \"right_child\": 2\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 1,\n" +
|
||||
" \"number_samples\": 10,\n" +
|
||||
" \"leaf_value\": 1\n" +
|
||||
" },\n" +
|
||||
" {\n" +
|
||||
" \"node_index\": 2,\n" +
|
||||
" \"number_samples\": 170,\n" +
|
||||
" \"leaf_value\": 0\n" +
|
||||
" }\n" +
|
||||
" ],\n" +
|
||||
" \"target_type\": \"regression\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" ]\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests.CLASSIFICATION_DEFINITION;
|
||||
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinitionTests.getClassificationDefinition;
|
||||
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
||||
import static org.hamcrest.CoreMatchers.containsString;
|
||||
|
||||
|
@ -532,7 +532,7 @@ public class InferenceIngestIT extends ESRestTestCase {
|
|||
" \"description\": \"test model for classification\",\n" +
|
||||
" \"default_field_map\": {\"col_1_alias\": \"col1\"},\n" +
|
||||
" \"inference_config\": {\"classification\": {}},\n" +
|
||||
" \"definition\": " + CLASSIFICATION_DEFINITION +
|
||||
" \"definition\": " + getClassificationDefinition(false) +
|
||||
"}";
|
||||
|
||||
private static String pipelineDefinition(String modelId, String inferenceConfig) {
|
||||
|
|
Loading…
Reference in New Issue