mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-26 06:46:10 +00:00
Feature importance storage format is changing to encompass multi-class.

Feature importance objects are now mapped as follows:

(Logistic) regression:
```
{
  "feature_name": "feature_0",
  "importance": -1.3
}
```

Multi-class [class names are `foo`, `bar`, `baz`]:
```
{
  "feature_name": "feature_0",
  "importance": 2.0, // sum(abs()) of class importances
  "foo": 1.0,
  "bar": 0.5,
  "baz": -0.5
}
```

This change adjusts the mapping creation for analytics so that the field is mapped as a `nested` type.

Native side change: https://github.com/elastic/ml-cpp/pull/1071
This commit is contained in:
parent
181bc807be
commit
d276058c6c
@ -288,9 +288,11 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||||
|
Map<String, Object> additionalProperties = new HashMap<>();
|
||||||
|
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
|
||||||
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
|
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
|
||||||
if ((dependentVariableMapping instanceof Map) == false) {
|
if ((dependentVariableMapping instanceof Map) == false) {
|
||||||
return Collections.emptyMap();
|
return additionalProperties;
|
||||||
}
|
}
|
||||||
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
|
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
|
||||||
// If the source field is an alias, fetch the concrete field that the alias points to.
|
// If the source field is an alias, fetch the concrete field that the alias points to.
|
||||||
@ -301,9 +303,8 @@ public class Classification implements DataFrameAnalysis {
|
|||||||
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
|
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
|
||||||
// Hence, we need to check the "instanceof" condition again.
|
// Hence, we need to check the "instanceof" condition again.
|
||||||
if ((dependentVariableMapping instanceof Map) == false) {
|
if ((dependentVariableMapping instanceof Map) == false) {
|
||||||
return Collections.emptyMap();
|
return additionalProperties;
|
||||||
}
|
}
|
||||||
Map<String, Object> additionalProperties = new HashMap<>();
|
|
||||||
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
|
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
|
||||||
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
|
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
|
||||||
return additionalProperties;
|
return additionalProperties;
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*//*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
||||||
|
|
||||||
|
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
||||||
|
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
final class MapUtils {
|
||||||
|
|
||||||
|
private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
|
||||||
|
static {
|
||||||
|
Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
|
||||||
|
featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
|
||||||
|
featureImportanceMappingProperties.put("importance",
|
||||||
|
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
|
||||||
|
Map<String, Object> featureImportanceMapping = new HashMap<>();
|
||||||
|
// TODO sorted indices don't support nested types
|
||||||
|
//featureImportanceMapping.put("dynamic", true);
|
||||||
|
//featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
|
||||||
|
featureImportanceMapping.put("properties", featureImportanceMappingProperties);
|
||||||
|
FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Map<String, Object> featureImportanceMapping() {
|
||||||
|
return FEATURE_IMPORTANCE_MAPPING;
|
||||||
|
}
|
||||||
|
|
||||||
|
private MapUtils() {}
|
||||||
|
}
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@ -187,9 +188,13 @@ public class Regression implements DataFrameAnalysis {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
|
||||||
|
Map<String, Object> additionalProperties = new HashMap<>();
|
||||||
|
additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
|
||||||
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
|
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
|
||||||
// high (over 10M) values of dependent variable.
|
// high (over 10M) values of dependent variable.
|
||||||
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
|
additionalProperties.put(resultsFieldName + "." + predictionFieldName,
|
||||||
|
Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
|
||||||
|
return additionalProperties;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -27,7 +27,6 @@ import java.util.Map;
|
|||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.allOf;
|
import static org.hamcrest.Matchers.allOf;
|
||||||
import static org.hamcrest.Matchers.anEmptyMap;
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.empty;
|
import static org.hamcrest.Matchers.empty;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
@ -244,31 +243,37 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void testGetExplicitlyMappedFields() {
|
public void testGetExplicitlyMappedFields() {
|
||||||
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
|
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
|
||||||
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
||||||
|
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
|
||||||
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
||||||
assertThat(
|
assertThat(
|
||||||
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
|
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
|
||||||
is(anEmptyMap()));
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
||||||
assertThat(
|
Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
|
||||||
new Classification("foo").getExplicitlyMappedFields(
|
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
|
||||||
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
|
"results");
|
||||||
"results"),
|
assertThat(explicitlyMappedFields,
|
||||||
allOf(
|
allOf(
|
||||||
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
|
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
|
||||||
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
|
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
|
||||||
assertThat(
|
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
|
||||||
new Classification("foo").getExplicitlyMappedFields(
|
|
||||||
new HashMap<String, Object>() {{
|
explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
|
||||||
put("foo", new HashMap<String, String>() {{
|
new HashMap<String, Object>() {{
|
||||||
put("type", "alias");
|
put("foo", new HashMap<String, String>() {{
|
||||||
put("path", "bar");
|
put("type", "alias");
|
||||||
}});
|
put("path", "bar");
|
||||||
put("bar", Collections.singletonMap("type", "long"));
|
}});
|
||||||
}},
|
put("bar", Collections.singletonMap("type", "long"));
|
||||||
"results"),
|
}},
|
||||||
|
"results");
|
||||||
|
assertThat(explicitlyMappedFields,
|
||||||
allOf(
|
allOf(
|
||||||
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
|
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
|
||||||
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
|
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
|
||||||
|
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
|
||||||
|
|
||||||
assertThat(
|
assertThat(
|
||||||
new Classification("foo").getExplicitlyMappedFields(
|
new Classification("foo").getExplicitlyMappedFields(
|
||||||
Collections.singletonMap("foo", new HashMap<String, String>() {{
|
Collections.singletonMap("foo", new HashMap<String, String>() {{
|
||||||
@ -276,7 +281,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|||||||
put("path", "missing");
|
put("path", "missing");
|
||||||
}}),
|
}}),
|
||||||
"results"),
|
"results"),
|
||||||
is(anEmptyMap()));
|
equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
|
||||||
|
@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
|
|||||||
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.allOf;
|
import static org.hamcrest.Matchers.allOf;
|
||||||
@ -143,9 +144,9 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void testGetExplicitlyMappedFields() {
|
public void testGetExplicitlyMappedFields() {
|
||||||
assertThat(
|
Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
|
||||||
new Regression("foo").getExplicitlyMappedFields(null, "results"),
|
assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
|
||||||
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
|
assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetStateDocId() {
|
public void testGetStateDocId() {
|
||||||
|
@ -77,7 +77,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||||||
cleanUp();
|
cleanUp();
|
||||||
}
|
}
|
||||||
|
|
||||||
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
|
|
||||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||||
initialize("classification_single_numeric_feature_and_mixed_data_set");
|
initialize("classification_single_numeric_feature_and_mixed_data_set");
|
||||||
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||||
@ -109,7 +108,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||||||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
||||||
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
||||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||||
assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true));
|
@SuppressWarnings("unchecked")
|
||||||
|
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
|
||||||
|
assertThat(importanceArray, hasSize(greaterThan(0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
|
@ -27,9 +27,11 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
|
||||||
import static org.hamcrest.Matchers.anyOf;
|
import static org.hamcrest.Matchers.anyOf;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.greaterThan;
|
import static org.hamcrest.Matchers.greaterThan;
|
||||||
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
import static org.hamcrest.Matchers.is;
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
@ -50,7 +52,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||||||
cleanUp();
|
cleanUp();
|
||||||
}
|
}
|
||||||
|
|
||||||
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/53236")
|
|
||||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||||
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
||||||
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
||||||
@ -88,11 +89,13 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||||||
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
||||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||||
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
|
||||||
|
assertThat(importanceArray, hasSize(greaterThan(0)));
|
||||||
assertThat(
|
assertThat(
|
||||||
resultsObject.toString(),
|
importanceArray.stream().filter(m -> NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))
|
||||||
resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
|
|| DISCRETE_NUMERICAL_FEATURE_FIELD.equals(m.get("feature_name"))).findAny(),
|
||||||
|| resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
|
isPresent());
|
||||||
is(true));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user