[ML] adjusting feature importance mapping for multi-class support (#53821) (#54013)

Feature importance storage format is changing to encompass multi-class.

Feature importance objects are now mapped as follows
(logistic) Regression:
```
{
   "feature_name": "feature_0",
   "importance": -1.3
}
```
Multi-class [class names are `foo`, `bar`, `baz`]
```
{
   “feature_name”: “feature_0”,
   “importance”: 2.0, // sum(abs()) of class importances
   “foo”: 1.0,
   “bar”: 0.5,
   “baz”: -0.5
},
```

This change adjusts the mapping creation for analytics so that the field is mapped as a `nested` type.

Native side change: https://github.com/elastic/ml-cpp/pull/1071
This commit is contained in:
Benjamin Trent 2020-03-23 15:50:12 -04:00 committed by GitHub
parent 181bc807be
commit d276058c6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 89 additions and 33 deletions

View File

@ -288,9 +288,11 @@ public class Classification implements DataFrameAnalysis {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) { public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
Map<String, Object> additionalProperties = new HashMap<>();
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties); Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
if ((dependentVariableMapping instanceof Map) == false) { if ((dependentVariableMapping instanceof Map) == false) {
return Collections.emptyMap(); return additionalProperties;
} }
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping; Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
// If the source field is an alias, fetch the concrete field that the alias points to. // If the source field is an alias, fetch the concrete field that the alias points to.
@ -301,9 +303,8 @@ public class Classification implements DataFrameAnalysis {
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above. // We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
// Hence, we need to check the "instanceof" condition again. // Hence, we need to check the "instanceof" condition again.
if ((dependentVariableMapping instanceof Map) == false) { if ((dependentVariableMapping instanceof Map) == false) {
return Collections.emptyMap(); return additionalProperties;
} }
Map<String, Object> additionalProperties = new HashMap<>();
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping); additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping); additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
return additionalProperties; return additionalProperties;

View File

@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*//*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
import org.elasticsearch.index.mapper.KeywordFieldMapper;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
final class MapUtils {
private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
static {
Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
featureImportanceMappingProperties.put("importance",
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
Map<String, Object> featureImportanceMapping = new HashMap<>();
// TODO sorted indices don't support nested types
//featureImportanceMapping.put("dynamic", true);
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
}
static Map<String, Object> featureImportanceMapping() {
return FEATURE_IMPORTANCE_MAPPING;
}
private MapUtils() {}
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException; import java.io.IOException;
@ -187,9 +188,13 @@ public class Regression implements DataFrameAnalysis {
@Override @Override
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) { public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
Map<String, Object> additionalProperties = new HashMap<>();
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of // Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
// high (over 10M) values of dependent variable. // high (over 10M) values of dependent variable.
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double")); additionalProperties.put(resultsFieldName + "." + predictionFieldName,
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
return additionalProperties;
} }
@Override @Override

View File

@ -27,7 +27,6 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -244,20 +243,23 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
} }
public void testGetExplicitlyMappedFields() { public void testGetExplicitlyMappedFields() {
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap())); assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap())); equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
assertThat( assertThat(
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"), new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
is(anEmptyMap())); equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
assertThat( Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")), Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
"results"), "results");
assertThat(explicitlyMappedFields,
allOf( allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")), hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz")))); hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
assertThat( assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
new Classification("foo").getExplicitlyMappedFields(
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
new HashMap<String, Object>() {{ new HashMap<String, Object>() {{
put("foo", new HashMap<String, String>() {{ put("foo", new HashMap<String, String>() {{
put("type", "alias"); put("type", "alias");
@ -265,10 +267,13 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
}}); }});
put("bar", Collections.singletonMap("type", "long")); put("bar", Collections.singletonMap("type", "long"));
}}, }},
"results"), "results");
assertThat(explicitlyMappedFields,
allOf( allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")), hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long")))); hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
assertThat( assertThat(
new Classification("foo").getExplicitlyMappedFields( new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", new HashMap<String, String>() {{ Collections.singletonMap("foo", new HashMap<String, String>() {{
@ -276,7 +281,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
put("path", "missing"); put("path", "missing");
}}), }}),
"results"), "results"),
is(anEmptyMap())); equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
} }
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

View File

@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import java.io.IOException; import java.io.IOException;
import java.util.Map;
import java.util.Collections; import java.util.Collections;
import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.allOf;
@ -143,9 +144,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
} }
public void testGetExplicitlyMappedFields() { public void testGetExplicitlyMappedFields() {
assertThat( Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
new Regression("foo").getExplicitlyMappedFields(null, "results"), assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double"))); assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
} }
public void testGetStateDocId() { public void testGetStateDocId() {

View File

@ -77,7 +77,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
cleanUp(); cleanUp();
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("classification_single_numeric_feature_and_mixed_data_set"); initialize("classification_single_numeric_feature_and_mixed_data_set");
String predictedClassField = KEYWORD_FIELD + "_prediction"; String predictedClassField = KEYWORD_FIELD + "_prediction";
@ -109,7 +108,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true)); @SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
assertThat(importanceArray, hasSize(greaterThan(0)));
} }
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);

View File

@ -27,9 +27,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@ -50,7 +52,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
cleanUp(); cleanUp();
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
initialize("regression_single_numeric_feature_and_mixed_data_set"); initialize("regression_single_numeric_feature_and_mixed_data_set");
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction"; String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
@ -88,11 +89,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(resultsObject.containsKey(predictedClassField), is(true)); assertThat(resultsObject.containsKey(predictedClassField), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
@SuppressWarnings("unchecked")
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
assertThat(importanceArray, hasSize(greaterThan(0)));
assertThat( assertThat(
resultsObject.toString(), importanceArray.stream().filter(m -> NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))
resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD) || DISCRETE_NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))).findAny(),
|| resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD), isPresent());
is(true));
} }
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);