diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 7d88270adce..4d5b9ffefcd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -314,7 +314,7 @@ public class Classification implements DataFrameAnalysis { @Override public Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName) { Map additionalProperties = new HashMap<>(); - additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping()); + additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.classificationFeatureImportanceMapping()); Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties); if ((dependentVariableMapping instanceof Map) == false) { return additionalProperties; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java index 5440bc850c2..3cc8825944f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java @@ -18,22 +18,46 @@ import java.util.Map; final class MapUtils { - private static final Map FEATURE_IMPORTANCE_MAPPING; - static { - Map featureImportanceMappingProperties = new HashMap<>(); + private static Map createFeatureImportanceMapping(Map featureImportanceMappingProperties){ featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); - featureImportanceMappingProperties.put("importance", - Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); Map featureImportanceMapping = new HashMap<>(); // TODO sorted indices don't support nested types //featureImportanceMapping.put("dynamic", true); //featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); featureImportanceMapping.put("properties", featureImportanceMappingProperties); - FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping); + return featureImportanceMapping; } - static Map featureImportanceMapping() { - return FEATURE_IMPORTANCE_MAPPING; + private static final Map CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING; + static { + Map classImportancePropertiesMapping = new HashMap<>(); + // TODO sorted indices don't support nested types + //classImportancePropertiesMapping.put("dynamic", true); + //classImportancePropertiesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE); + classImportancePropertiesMapping.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE)); + classImportancePropertiesMapping.put("importance", + Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + Map featureImportancePropertiesMapping = new HashMap<>(); + featureImportancePropertiesMapping.put("classes", Collections.singletonMap("properties", classImportancePropertiesMapping)); + CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING = + Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping)); + } + + private static final Map REGRESSION_FEATURE_IMPORTANCE_MAPPING; + static { + Map featureImportancePropertiesMapping = new HashMap<>(); + featureImportancePropertiesMapping.put("importance", + Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName())); + REGRESSION_FEATURE_IMPORTANCE_MAPPING = + Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping)); + } + + static Map regressionFeatureImportanceMapping() { + return REGRESSION_FEATURE_IMPORTANCE_MAPPING; + } + + static Map classificationFeatureImportanceMapping() { + return CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING; } private MapUtils() {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index cac6a923160..b4b06187fbe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -247,7 +247,7 @@ public class Regression implements DataFrameAnalysis { @Override public Map getExplicitlyMappedFields(Map mappingsProperties, String resultsFieldName) { Map additionalProperties = new HashMap<>(); - additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping()); + additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.regressionFeatureImportanceMapping()); // Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of // high (over 10M) values of dependent variable. additionalProperties.put(resultsFieldName + "." + predictionFieldName, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java index 1f78ba11e31..3c1a395a1f7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -16,65 +17,74 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; public class FeatureImportance implements Writeable, ToXContentObject { - private final Map classImportance; + private final List classImportance; private final double importance; private final String featureName; static final String IMPORTANCE = "importance"; static final String FEATURE_NAME = "feature_name"; - static final String CLASS_IMPORTANCE = "class_importance"; + static final String CLASSES = "classes"; public static FeatureImportance forRegression(String featureName, double importance) { return new FeatureImportance(featureName, importance, null); } - public static FeatureImportance forClassification(String featureName, Map classImportance) { - return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance); + public static FeatureImportance forClassification(String featureName, List classImportance) { + return new FeatureImportance(featureName, + classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(), + classImportance); } @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("feature_importance", - a -> new FeatureImportance((String) a[0], (Double) a[1], (Map) a[2]) + a -> new FeatureImportance((String) a[0], (Double) a[1], (List) a[2]) ); static { PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); - PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue), - new ParseField(FeatureImportance.CLASS_IMPORTANCE)); + PARSER.declareObjectArray(optionalConstructorArg(), + (p, c) -> ClassImportance.fromXContent(p), + new ParseField(FeatureImportance.CLASSES)); } public static FeatureImportance fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - FeatureImportance(String featureName, double importance, Map classImportance) { + FeatureImportance(String featureName, double importance, List classImportance) { this.featureName = Objects.requireNonNull(featureName); this.importance = importance; - this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance); + this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance); } public FeatureImportance(StreamInput in) throws IOException { this.featureName = in.readString(); this.importance = in.readDouble(); if (in.readBoolean()) { - this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble); + if (in.getVersion().before(Version.V_7_10_0)) { + Map classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble); + this.classImportance = ClassImportance.fromMap(classImportance); + } else { + this.classImportance = in.readList(ClassImportance::new); + } } else { this.classImportance = null; } } - public Map getClassImportance() { + public List getClassImportance() { return classImportance; } @@ -92,7 +102,11 @@ public class FeatureImportance implements Writeable, ToXContentObject { out.writeDouble(this.importance); out.writeBoolean(this.classImportance != null); if (this.classImportance != null) { - out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble); + if (out.getVersion().before(Version.V_7_10_0)) { + out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble); + } else { + out.writeList(this.classImportance); + } } } @@ -101,7 +115,7 @@ public class FeatureImportance implements Writeable, ToXContentObject { map.put(FEATURE_NAME, featureName); map.put(IMPORTANCE, importance); if (classImportance != null) { - classImportance.forEach(map::put); + map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList())); } return map; } @@ -112,11 +126,7 @@ public class FeatureImportance implements Writeable, ToXContentObject { builder.field(FEATURE_NAME, featureName); builder.field(IMPORTANCE, importance); if (classImportance != null && classImportance.isEmpty() == false) { - builder.startObject(CLASS_IMPORTANCE); - for (Map.Entry entry : classImportance.entrySet()) { - builder.field(entry.getKey(), entry.getValue()); - } - builder.endObject(); + builder.field(CLASSES, classImportance); } builder.endObject(); return builder; @@ -136,4 +146,92 @@ public class FeatureImportance implements Writeable, ToXContentObject { public int hashCode() { return Objects.hash(featureName, importance, classImportance); } + + public static class ClassImportance implements Writeable, ToXContentObject { + + static final String CLASS_NAME = "class_name"; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("feature_importance_class_importance", + a -> new ClassImportance((String) a[0], (Double) a[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME)); + PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); + } + + private static ClassImportance fromMapEntry(Map.Entry entry) { + return new ClassImportance(entry.getKey(), entry.getValue()); + } + + private static List fromMap(Map classImportanceMap) { + return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList()); + } + + private static Map toMap(List importances) { + return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance)); + } + + public static ClassImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String className; + private final double importance; + + public ClassImportance(String className, double importance) { + this.className = className; + this.importance = importance; + } + + public ClassImportance(StreamInput in) throws IOException { + this.className = in.readString(); + this.importance = in.readDouble(); + } + + public String getClassName() { + return className; + } + + public double getImportance() { + return importance; + } + + public Map toMap() { + Map map = new LinkedHashMap<>(); + map.put(CLASS_NAME, className); + map.put(IMPORTANCE, importance); + return map; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(className); + out.writeDouble(importance); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME, className); + builder.field(IMPORTANCE, importance); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassImportance that = (ClassImportance) o; + return Double.compare(that.importance, importance) == 0 && + Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(className, importance); + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 5bfa4e054ff..d4cadf33bf4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -15,7 +15,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -139,11 +138,13 @@ public final class InferenceHelpers { if (v.length == 1) { importances.add(FeatureImportance.forRegression(k, v[0])); } else { - Map classImportance = new LinkedHashMap<>(v.length, 1.0f); + List classImportance = new ArrayList<>(v.length); // If the classificationLabels exist, their length must match leaf_value length assert classificationLabels == null || classificationLabels.size() == v.length; for (int i = 0; i < v.length; i++) { - classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]); + classImportance.add(new FeatureImportance.ClassImportance( + classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), + v[i])); } importances.add(FeatureImportance.forClassification(k, classImportance)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 4344b4bfb7b..426018d89c0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -261,12 +261,12 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields( Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")), "results"); @@ -274,7 +274,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase() {{ @@ -289,7 +289,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results"); assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double"))); - assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping())); + assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.regressionFeatureImportanceMapping())); } public void testGetStateDocId() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index efeb2cdb256..64ca2b1592a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -152,8 +152,15 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing FeatureImportance importance = importanceList.get(i); assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName())); assertThat(objectMap.get("importance"), equalTo(importance.getImportance())); + @SuppressWarnings("unchecked") + List> classImportances = (List>)objectMap.get("classes"); if (importance.getClassImportance() != null) { - importance.getClassImportance().forEach((k, v) -> assertThat(objectMap.get(k), equalTo(v))); + for (int j = 0; j < importance.getClassImportance().size(); j++) { + Map classMap = classImportances.get(j); + FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j); + assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName())); + assertThat(classMap.get("importance"), equalTo(classImportance.getImportance())); + } } } } @@ -205,7 +212,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}"; assertEquals(expected, stringRep); - FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap()); + FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0); result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp), Collections.singletonList(fi), config, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java index f23366b1078..6a3563f3a46 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -29,7 +28,8 @@ public class FeatureImportanceTests extends AbstractSerializingTestCase randomAlphaOfLength(10)) .limit(randomLongBetween(2, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false)))); + .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false))) + .collect(Collectors.toList())); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 91899b688ae..29a40248474 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -92,7 +92,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest String expected = "{\"" + resultsField + "\":1.0}"; assertEquals(expected, stringRep); - FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap()); + FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi)); stringRep = Strings.toString(result); expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";