diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java index c83b90fcc15..23c2aa168b3 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java @@ -47,7 +47,7 @@ public class FeatureImportance implements ToXContentObject { static { PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); - PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); + PARSER.declareDouble(optionalConstructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> ClassImportance.fromXContent(p), new ParseField(FeatureImportance.CLASSES)); @@ -58,10 +58,10 @@ public class FeatureImportance implements ToXContentObject { } private final List classImportance; - private final double importance; + private final Double importance; private final String featureName; - public FeatureImportance(String featureName, double importance, List classImportance) { + public FeatureImportance(String featureName, Double importance, List classImportance) { this.featureName = Objects.requireNonNull(featureName); this.importance = importance; this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance); @@ -71,7 +71,7 @@ public class FeatureImportance implements ToXContentObject { return classImportance; } - public double getImportance() { + public Double getImportance() { return importance; } @@ -83,7 +83,9 @@ public class FeatureImportance implements ToXContentObject { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(FEATURE_NAME, featureName); - builder.field(IMPORTANCE, importance); + if (importance != null) { + builder.field(IMPORTANCE, importance); + } if (classImportance != null && classImportance.isEmpty() == false) { builder.field(CLASSES, classImportance); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java index 0da86667e1d..dfb3118c4c4 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java @@ -32,7 +32,7 @@ public class FeatureImportanceTests extends AbstractXContentTestCase randomAlphaOfLength(10)) .limit(randomLongBetween(2, 10)) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/AbstractFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/AbstractFeatureImportance.java new file mode 100644 index 00000000000..ebf182aeac0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/AbstractFeatureImportance.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +abstract class AbstractFeatureImportance implements Writeable, ToXContentObject { + + public abstract String getFeatureName(); + + public abstract Map toMap(); + + @Override + public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.map(toMap()); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java similarity index 51% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java index 0846acf3331..7eff392dabe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.Version; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -26,157 +25,101 @@ import java.util.stream.Collectors; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; -public class FeatureImportance implements Writeable, ToXContentObject { +public class ClassificationFeatureImportance extends AbstractFeatureImportance { private final List classImportance; - private final double importance; private final String featureName; - static final String IMPORTANCE = "importance"; + static final String FEATURE_NAME = "feature_name"; static final String CLASSES = "classes"; - public static FeatureImportance forRegression(String featureName, double importance) { - return new FeatureImportance(featureName, importance, null); - } - - public static FeatureImportance forBinaryClassification(String featureName, double importance, List classImportance) { - return new FeatureImportance(featureName, - importance, - classImportance); - } - - public static FeatureImportance forClassification(String featureName, List classImportance) { - return new FeatureImportance(featureName, - classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(), - classImportance); - } - @SuppressWarnings("unchecked") - private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("feature_importance", - a -> new FeatureImportance((String) a[0], (Double) a[1], (List) a[2]) + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("classification_feature_importance", + a -> new ClassificationFeatureImportance((String) a[0], (List) a[1]) ); static { - PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); - PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); + PARSER.declareString(constructorArg(), new ParseField(ClassificationFeatureImportance.FEATURE_NAME)); PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> ClassImportance.fromXContent(p), - new ParseField(FeatureImportance.CLASSES)); + new ParseField(ClassificationFeatureImportance.CLASSES)); } - public static FeatureImportance fromXContent(XContentParser parser) { + public static ClassificationFeatureImportance fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - FeatureImportance(String featureName, double importance, List classImportance) { + public ClassificationFeatureImportance(String featureName, List classImportance) { this.featureName = Objects.requireNonNull(featureName); - this.importance = importance; - this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance); + this.classImportance = classImportance == null ? Collections.emptyList() : Collections.unmodifiableList(classImportance); } - public FeatureImportance(StreamInput in) throws IOException { + public ClassificationFeatureImportance(StreamInput in) throws IOException { this.featureName = in.readString(); - this.importance = in.readDouble(); - if (in.readBoolean()) { - if (in.getVersion().before(Version.V_7_10_0)) { - Map classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble); - this.classImportance = ClassImportance.fromMap(classImportance); - } else { - this.classImportance = in.readList(ClassImportance::new); - } - } else { - this.classImportance = null; - } + this.classImportance = in.readList(ClassImportance::new); } public List getClassImportance() { return classImportance; } - public double getImportance() { - return importance; - } - + @Override public String getFeatureName() { return featureName; } - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(this.featureName); - out.writeDouble(this.importance); - out.writeBoolean(this.classImportance != null); - if (this.classImportance != null) { - if (out.getVersion().before(Version.V_7_10_0)) { - out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble); - } else { - out.writeList(this.classImportance); - } + public double getTotalImportance() { + if (classImportance.size() == 2) { + // Binary classification. We can return the first class importance here + return Math.abs(classImportance.get(0).getImportance()); } + return classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(); } + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureName); + out.writeList(classImportance); + } + + @Override public Map toMap() { Map map = new LinkedHashMap<>(); map.put(FEATURE_NAME, featureName); - map.put(IMPORTANCE, importance); - if (classImportance != null) { + if (classImportance.isEmpty() == false) { map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList())); } return map; } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(FEATURE_NAME, featureName); - builder.field(IMPORTANCE, importance); - if (classImportance != null && classImportance.isEmpty() == false) { - builder.field(CLASSES, classImportance); - } - builder.endObject(); - return builder; - } - @Override public boolean equals(Object object) { if (object == this) { return true; } if (object == null || getClass() != object.getClass()) { return false; } - FeatureImportance that = (FeatureImportance) object; + ClassificationFeatureImportance that = (ClassificationFeatureImportance) object; return Objects.equals(featureName, that.featureName) - && Objects.equals(importance, that.importance) && Objects.equals(classImportance, that.classImportance); } @Override public int hashCode() { - return Objects.hash(featureName, importance, classImportance); + return Objects.hash(featureName, classImportance); } public static class ClassImportance implements Writeable, ToXContentObject { static final String CLASS_NAME = "class_name"; + static final String IMPORTANCE = "importance"; private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("feature_importance_class_importance", - a -> new ClassImportance((String) a[0], (Double) a[1]) + new ConstructingObjectParser<>("classification_feature_importance_class_importance", + a -> new ClassImportance(a[0], (Double) a[1]) ); static { PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME)); - PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); - } - - private static ClassImportance fromMapEntry(Map.Entry entry) { - return new ClassImportance(entry.getKey(), entry.getValue()); - } - - private static List fromMap(Map classImportanceMap) { - return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList()); - } - - private static Map toMap(List importances) { - return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance)); + PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE)); } public static ClassImportance fromXContent(XContentParser parser) { @@ -219,11 +162,7 @@ public class FeatureImportance implements Writeable, ToXContentObject { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME, className); - builder.field(IMPORTANCE, importance); - builder.endObject(); - return builder; + return builder.map(toMap()); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 0f846203f86..ae780234caa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -15,9 +15,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldTyp import java.io.IOException; import java.util.Collections; -import java.util.Map; import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; @@ -34,12 +34,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults private final Double predictionProbability; private final Double predictionScore; private final List topClasses; + private final List featureImportance; private final PredictionFieldType predictionFieldType; public ClassificationInferenceResults(double value, String classificationLabel, List topClasses, - List featureImportance, + List featureImportance, InferenceConfig config, Double predictionProbability, Double predictionScore) { @@ -55,13 +56,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults private ClassificationInferenceResults(double value, String classificationLabel, List topClasses, - List featureImportance, + List featureImportance, ClassificationConfig classificationConfig, Double predictionProbability, Double predictionScore) { - super(value, - SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, - classificationConfig.getNumTopFeatureImportanceValues())); + super(value); this.classificationLabel = classificationLabel; this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); this.topNumClassesField = classificationConfig.getTopClassesResultsField(); @@ -69,10 +68,32 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults this.predictionFieldType = classificationConfig.getPredictionFieldType(); this.predictionProbability = predictionProbability; this.predictionScore = predictionScore; + this.featureImportance = takeTopFeatureImportances(featureImportance, classificationConfig.getNumTopFeatureImportanceValues()); + } + + static List takeTopFeatureImportances(List featureImportances, + int numTopFeatures) { + if (featureImportances == null || featureImportances.isEmpty()) { + return Collections.emptyList(); + } + return featureImportances.stream() + .sorted((l, r)-> Double.compare(r.getTotalImportance(), l.getTotalImportance())) + .limit(numTopFeatures) + .collect(Collectors.toList()); } public ClassificationInferenceResults(StreamInput in) throws IOException { super(in); + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + this.featureImportance = in.readList(ClassificationFeatureImportance::new); + } else if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + this.featureImportance = in.readList(LegacyFeatureImportance::new) + .stream() + .map(LegacyFeatureImportance::forClassification) + .collect(Collectors.toList()); + } else { + this.featureImportance = Collections.emptyList(); + } this.classificationLabel = in.readOptionalString(); this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); this.topNumClassesField = in.readString(); @@ -103,9 +124,18 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults return predictionFieldType; } + public List getFeatureImportance() { + return featureImportance; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + if (out.getVersion().onOrAfter(Version.V_7_10_0)) { + out.writeList(featureImportance); + } else if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromClassification).collect(Collectors.toList())); + } out.writeOptionalString(classificationLabel); out.writeCollection(topClasses); out.writeString(topNumClassesField); @@ -132,7 +162,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults && Objects.equals(predictionFieldType, that.predictionFieldType) && Objects.equals(predictionProbability, that.predictionProbability) && Objects.equals(predictionScore, that.predictionScore) - && Objects.equals(getFeatureImportance(), that.getFeatureImportance()); + && Objects.equals(featureImportance, that.featureImportance); } @Override @@ -144,7 +174,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults topNumClassesField, predictionProbability, predictionScore, - getFeatureImportance(), + featureImportance, predictionFieldType); } @@ -179,8 +209,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults if (predictionScore != null) { map.put(PREDICTION_SCORE, predictionScore); } - if (getFeatureImportance().isEmpty() == false) { - map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList())); + if (featureImportance.isEmpty() == false) { + map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(ClassificationFeatureImportance::toMap) + .collect(Collectors.toList())); } return map; } @@ -202,8 +233,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults if (predictionScore != null) { builder.field(PREDICTION_SCORE, predictionScore); } - if (getFeatureImportance().size() > 0) { - builder.field(FEATURE_IMPORTANCE, getFeatureImportance()); + if (featureImportance.isEmpty() == false) { + builder.field(FEATURE_IMPORTANCE, featureImportance); } return builder; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportance.java new file mode 100644 index 00000000000..7881f3edaab --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportance.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * This class captures serialization of feature importance for + * classification and regression prior to version 7.10. + */ +public class LegacyFeatureImportance implements Writeable { + + public static LegacyFeatureImportance fromClassification(ClassificationFeatureImportance classificationFeatureImportance) { + return new LegacyFeatureImportance( + classificationFeatureImportance.getFeatureName(), + classificationFeatureImportance.getTotalImportance(), + classificationFeatureImportance.getClassImportance().stream().map(classImportance -> new ClassImportance( + classImportance.getClassName(), classImportance.getImportance())).collect(Collectors.toList()) + ); + } + + public static LegacyFeatureImportance fromRegression(RegressionFeatureImportance regressionFeatureImportance) { + return new LegacyFeatureImportance( + regressionFeatureImportance.getFeatureName(), + regressionFeatureImportance.getImportance(), + null + ); + } + + private final List classImportance; + private final double importance; + private final String featureName; + + LegacyFeatureImportance(String featureName, double importance, List classImportance) { + this.featureName = Objects.requireNonNull(featureName); + this.importance = importance; + this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance); + } + + public LegacyFeatureImportance(StreamInput in) throws IOException { + this.featureName = in.readString(); + this.importance = in.readDouble(); + if (in.readBoolean()) { + if (in.getVersion().before(Version.V_7_10_0)) { + Map classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble); + this.classImportance = ClassImportance.fromMap(classImportance); + } else { + this.classImportance = in.readList(ClassImportance::new); + } + } else { + this.classImportance = null; + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureName); + out.writeDouble(importance); + out.writeBoolean(classImportance != null); + if (classImportance != null) { + if (out.getVersion().before(Version.V_7_10_0)) { + out.writeMap(ClassImportance.toMap(classImportance), StreamOutput::writeString, StreamOutput::writeDouble); + } else { + out.writeList(classImportance); + } + } + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + LegacyFeatureImportance that = (LegacyFeatureImportance) object; + return Objects.equals(featureName, that.featureName) + && Objects.equals(importance, that.importance) + && Objects.equals(classImportance, that.classImportance); + } + + @Override + public int hashCode() { + return Objects.hash(featureName, importance, classImportance); + } + + public RegressionFeatureImportance forRegression() { + assert classImportance == null; + return new RegressionFeatureImportance(featureName, importance); + } + + public ClassificationFeatureImportance forClassification() { + assert classImportance != null; + return new ClassificationFeatureImportance(featureName, classImportance.stream().map( + aClassImportance -> new ClassificationFeatureImportance.ClassImportance( + aClassImportance.className, aClassImportance.importance)).collect(Collectors.toList())); + } + + public static class ClassImportance implements Writeable { + + private static ClassImportance fromMapEntry(Map.Entry entry) { + return new ClassImportance(entry.getKey(), entry.getValue()); + } + + private static List fromMap(Map classImportanceMap) { + return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList()); + } + + private static Map toMap(List importances) { + return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance)); + } + + private final Object className; + private final double importance; + + public ClassImportance(Object className, double importance) { + this.className = className; + this.importance = importance; + } + + public ClassImportance(StreamInput in) throws IOException { + this.className = in.readGenericValue(); + this.importance = in.readDouble(); + } + + double getImportance() { + return importance; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeGenericValue(className); + out.writeDouble(importance); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassImportance that = (ClassImportance) o; + return Double.compare(that.importance, importance) == 0 && + Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(className, importance); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportance.java new file mode 100644 index 00000000000..d72a0ccaf3b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportance.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class RegressionFeatureImportance extends AbstractFeatureImportance { + + private final double importance; + private final String featureName; + static final String IMPORTANCE = "importance"; + static final String FEATURE_NAME = "feature_name"; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("regression_feature_importance", + a -> new RegressionFeatureImportance((String) a[0], (Double) a[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField(RegressionFeatureImportance.FEATURE_NAME)); + PARSER.declareDouble(constructorArg(), new ParseField(RegressionFeatureImportance.IMPORTANCE)); + } + + public static RegressionFeatureImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public RegressionFeatureImportance(String featureName, double importance) { + this.featureName = Objects.requireNonNull(featureName); + this.importance = importance; + } + + public RegressionFeatureImportance(StreamInput in) throws IOException { + this.featureName = in.readString(); + this.importance = in.readDouble(); + } + + public double getImportance() { + return importance; + } + + @Override + public String getFeatureName() { + return featureName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureName); + out.writeDouble(importance); + } + + @Override + public Map toMap() { + Map map = new LinkedHashMap<>(); + map.put(FEATURE_NAME, featureName); + map.put(IMPORTANCE, importance); + return map; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + RegressionFeatureImportance that = (RegressionFeatureImportance) object; + return Objects.equals(featureName, that.featureName) + && Objects.equals(importance, that.importance); + } + + @Override + public int hashCode() { + return Objects.hash(featureName, importance); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java index 498fd2828bc..0761337505f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -24,14 +25,19 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { public static final String NAME = "regression"; private final String resultsField; + private final List featureImportance; public RegressionInferenceResults(double value, InferenceConfig config) { this(value, config, Collections.emptyList()); } - public RegressionInferenceResults(double value, InferenceConfig config, List featureImportance) { - this(value, ((RegressionConfig)config).getResultsField(), - ((RegressionConfig)config).getNumTopFeatureImportanceValues(), featureImportance); + public RegressionInferenceResults(double value, InferenceConfig config, List featureImportance) { + this( + value, + ((RegressionConfig)config).getResultsField(), + ((RegressionConfig)config).getNumTopFeatureImportanceValues(), + featureImportance + ); } public RegressionInferenceResults(double value, String resultsField) { @@ -39,28 +45,58 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { } public RegressionInferenceResults(double value, String resultsField, - List featureImportance) { + List featureImportance) { this(value, resultsField, featureImportance.size(), featureImportance); } public RegressionInferenceResults(double value, String resultsField, int topNFeatures, - List featureImportance) { - super(value, - SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, topNFeatures)); + List featureImportance) { + super(value); this.resultsField = resultsField; + this.featureImportance = takeTopFeatureImportances(featureImportance, topNFeatures); + } + + static List takeTopFeatureImportances(List featureImportances, + int numTopFeatures) { + if (featureImportances == null || featureImportances.isEmpty()) { + return Collections.emptyList(); + } + return featureImportances.stream() + .sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance()))) + .limit(numTopFeatures) + .collect(Collectors.toList()); } public RegressionInferenceResults(StreamInput in) throws IOException { super(in); + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + this.featureImportance = in.readList(RegressionFeatureImportance::new); + } else if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + this.featureImportance = in.readList(LegacyFeatureImportance::new) + .stream() + .map(LegacyFeatureImportance::forRegression) + .collect(Collectors.toList()); + } else { + this.featureImportance = Collections.emptyList(); + } this.resultsField = in.readString(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + if (out.getVersion().onOrAfter(Version.V_7_10_0)) { + out.writeList(featureImportance); + } else if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromRegression).collect(Collectors.toList())); + } out.writeString(resultsField); } + public List getFeatureImportance() { + return featureImportance; + } + @Override public boolean equals(Object object) { if (object == this) { return true; } @@ -68,12 +104,12 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { RegressionInferenceResults that = (RegressionInferenceResults) object; return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField) - && Objects.equals(this.getFeatureImportance(), that.getFeatureImportance()); + && Objects.equals(this.featureImportance, that.featureImportance); } @Override public int hashCode() { - return Objects.hash(value(), resultsField, getFeatureImportance()); + return Objects.hash(value(), resultsField, featureImportance); } @Override @@ -85,8 +121,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { public Map asMap() { Map map = new LinkedHashMap<>(); map.put(resultsField, value()); - if (getFeatureImportance().isEmpty() == false) { - map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList())); + if (featureImportance.isEmpty() == false) { + map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList())); } return map; } @@ -94,8 +130,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(resultsField, value()); - if (getFeatureImportance().size() > 0) { - builder.field(FEATURE_IMPORTANCE, getFeatureImportance()); + if (featureImportance.isEmpty() == false) { + builder.field(FEATURE_IMPORTANCE, featureImportance); } return builder; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java index 25ede9955be..18c651f1278 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java @@ -5,53 +5,30 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; public abstract class SingleValueInferenceResults implements InferenceResults { public static final String FEATURE_IMPORTANCE = "feature_importance"; private final double value; - private final List featureImportance; - static List takeTopFeatureImportances(List unsortedFeatureImportances, int numTopFeatures) { - if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) { - return unsortedFeatureImportances; - } - return unsortedFeatureImportances.stream() - .sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance()))) - .limit(numTopFeatures) - .collect(Collectors.toList()); - } SingleValueInferenceResults(StreamInput in) throws IOException { value = in.readDouble(); - if (in.getVersion().onOrAfter(Version.V_7_7_0)) { - this.featureImportance = in.readList(FeatureImportance::new); - } else { - this.featureImportance = Collections.emptyList(); - } } - SingleValueInferenceResults(double value, List featureImportance) { + SingleValueInferenceResults(double value) { this.value = value; - this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance; } public Double value() { return value; } - public List getFeatureImportance() { - return featureImportance; - } public String valueAsString() { return String.valueOf(value); @@ -60,9 +37,6 @@ public abstract class SingleValueInferenceResults implements InferenceResults { @Override public void writeTo(StreamOutput out) throws IOException { out.writeDouble(value); - if (out.getVersion().onOrAfter(Version.V_7_7_0)) { - out.writeList(this.featureImportance); - } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 0b5bf658cb1..ed82a4da476 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -7,7 +7,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.collect.Tuple; -import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -130,17 +131,18 @@ public final class InferenceHelpers { return originalFeatureImportance; } - public static List transformFeatureImportanceRegression(Map featureImportance) { - List importances = new ArrayList<>(featureImportance.size()); - featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0]))); + public static List transformFeatureImportanceRegression(Map featureImportance) { + List importances = new ArrayList<>(featureImportance.size()); + featureImportance.forEach((k, v) -> importances.add(new RegressionFeatureImportance(k, v[0]))); return importances; } - public static List transformFeatureImportanceClassification(Map featureImportance, - final int predictedValue, - @Nullable List classificationLabels, - @Nullable PredictionFieldType predictionFieldType) { - List importances = new ArrayList<>(featureImportance.size()); + public static List transformFeatureImportanceClassification( + Map featureImportance, + final int predictedValue, + @Nullable List classificationLabels, + @Nullable PredictionFieldType predictionFieldType) { + List importances = new ArrayList<>(featureImportance.size()); final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType; featureImportance.forEach((k, v) -> { // This indicates logistic regression (binary classification) @@ -152,27 +154,26 @@ public final class InferenceHelpers { final int otherClass = 1 - predictedValue; String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue); String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass); - importances.add(FeatureImportance.forBinaryClassification(k, - v[0], + importances.add(new ClassificationFeatureImportance(k, Arrays.asList( - new FeatureImportance.ClassImportance( + new ClassificationFeatureImportance.ClassImportance( fieldType.transformPredictedValue((double)predictedValue, predictedLabel), v[0]), - new FeatureImportance.ClassImportance( + new ClassificationFeatureImportance.ClassImportance( fieldType.transformPredictedValue((double)otherClass, otherLabel), -v[0]) ))); } else { - List classImportance = new ArrayList<>(v.length); + List classImportance = new ArrayList<>(v.length); // If the classificationLabels exist, their length must match leaf_value length assert classificationLabels == null || classificationLabels.size() == v.length; for (int i = 0; i < v.length; i++) { String label = classificationLabels == null ? null : classificationLabels.get(i); - classImportance.add(new FeatureImportance.ClassImportance( + classImportance.add(new ClassificationFeatureImportance.ClassImportance( fieldType.transformPredictedValue((double)i, label), v[i])); } - importances.add(FeatureImportance.forClassification(k, classImportance)); + importances.add(new ClassificationFeatureImportance(k, classImportance)); } }); return importances; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java new file mode 100644 index 00000000000..6ef314cfe2c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class ClassificationFeatureImportanceTests extends AbstractSerializingTestCase { + + @Override + protected ClassificationFeatureImportance doParseInstance(XContentParser parser) throws IOException { + return ClassificationFeatureImportance.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationFeatureImportance::new; + } + + @Override + protected ClassificationFeatureImportance createTestInstance() { + return createRandomInstance(); + } + + public static ClassificationFeatureImportance createRandomInstance() { + return new ClassificationFeatureImportance( + randomAlphaOfLength(10), + Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomLongBetween(2, 10)) + .map(name -> new ClassificationFeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false))) + .collect(Collectors.toList())); + } + + public void testGetTotalImportance_GivenBinary() { + ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance( + "binary", + Arrays.asList( + new ClassificationFeatureImportance.ClassImportance("a", 0.15), + new ClassificationFeatureImportance.ClassImportance("not-a", -0.15) + ) + ); + + assertThat(featureImportance.getTotalImportance(), equalTo(0.15)); + } + + public void testGetTotalImportance_GivenMulticlass() { + ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance( + "multiclass", + Arrays.asList( + new ClassificationFeatureImportance.ClassImportance("a", 0.15), + new ClassificationFeatureImportance.ClassImportance("b", -0.05), + new ClassificationFeatureImportance.ClassImportance("c", 0.30) + ) + ); + + assertThat(featureImportance.getTotalImportance(), closeTo(0.50, 0.00000001)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index 64ca2b1592a..80d6c880a88 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -18,7 +18,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -29,10 +28,6 @@ import static org.hamcrest.Matchers.hasSize; public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase { public static ClassificationInferenceResults createRandomResults() { - Supplier featureImportanceCtor = randomBoolean() ? - FeatureImportanceTests::randomClassification : - FeatureImportanceTests::randomRegression; - ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig(); Double value = randomDouble(); if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) { @@ -47,7 +42,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing .limit(randomIntBetween(0, 10)) .collect(Collectors.toList()), randomBoolean() ? null : - Stream.generate(featureImportanceCtor) + Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance) .limit(randomIntBetween(1, 10)) .collect(Collectors.toList()), config, @@ -123,11 +118,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing } public void testWriteResultsWithImportance() { - Supplier featureImportanceCtor = randomBoolean() ? - FeatureImportanceTests::randomClassification : - FeatureImportanceTests::randomRegression; - - List importanceList = Stream.generate(featureImportanceCtor) + List importanceList = Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance) .limit(5) .collect(Collectors.toList()); ClassificationInferenceResults result = new ClassificationInferenceResults(0.0, @@ -146,18 +137,17 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing "result_field.feature_importance", List.class); assertThat(writtenImportance, hasSize(3)); - importanceList.sort((l, r) -> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance()))); + importanceList.sort((l, r) -> Double.compare(Math.abs(r.getTotalImportance()), Math.abs(l.getTotalImportance()))); for (int i = 0; i < 3; i++) { Map objectMap = writtenImportance.get(i); - FeatureImportance importance = importanceList.get(i); + ClassificationFeatureImportance importance = importanceList.get(i); assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName())); - assertThat(objectMap.get("importance"), equalTo(importance.getImportance())); @SuppressWarnings("unchecked") List> classImportances = (List>)objectMap.get("classes"); if (importance.getClassImportance() != null) { for (int j = 0; j < importance.getClassImportance().size(); j++) { Map classMap = classImportances.get(j); - FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j); + ClassificationFeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j); assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName())); assertThat(classMap.get("importance"), equalTo(classImportance.getImportance())); } @@ -212,7 +202,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}"; assertEquals(expected, stringRep); - FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); + ClassificationFeatureImportance fi = new ClassificationFeatureImportance("foo", Collections.emptyList()); TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0); result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp), Collections.singletonList(fi), config, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java deleted file mode 100644 index 6a3563f3a46..00000000000 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.inference.results; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractSerializingTestCase; - -import java.io.IOException; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -public class FeatureImportanceTests extends AbstractSerializingTestCase { - - public static FeatureImportance createRandomInstance() { - return randomBoolean() ? randomClassification() : randomRegression(); - } - - static FeatureImportance randomRegression() { - return FeatureImportance.forRegression(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false)); - } - - static FeatureImportance randomClassification() { - return FeatureImportance.forClassification( - randomAlphaOfLength(10), - Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomLongBetween(2, 10)) - .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false))) - .collect(Collectors.toList())); - } - - @Override - protected FeatureImportance createTestInstance() { - return createRandomInstance(); - } - - @Override - protected Writeable.Reader instanceReader() { - return FeatureImportance::new; - } - - @Override - protected FeatureImportance doParseInstance(XContentParser parser) throws IOException { - return FeatureImportance.fromXContent(parser); - } -} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportanceTests.java new file mode 100644 index 00000000000..9100be72734 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportanceTests.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class LegacyFeatureImportanceTests extends AbstractWireSerializingTestCase { + + public static LegacyFeatureImportance createRandomInstance() { + return createRandomInstance(randomBoolean()); + } + + public static LegacyFeatureImportance createRandomInstance(boolean hasClasses) { + double importance = randomDouble(); + List classImportances = null; + if (hasClasses) { + classImportances = Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomLongBetween(2, 10)) + .map(featureName -> new LegacyFeatureImportance.ClassImportance(featureName, randomDouble())) + .collect(Collectors.toList()); + + importance = classImportances.stream().mapToDouble(LegacyFeatureImportance.ClassImportance::getImportance).map(Math::abs).sum(); + } + return new LegacyFeatureImportance(randomAlphaOfLength(10), importance, classImportances); + } + + @Override + protected LegacyFeatureImportance createTestInstance() { + return createRandomInstance(); + } + + @Override + protected Writeable.Reader instanceReader() { + return LegacyFeatureImportance::new; + } + + public void testClassificationConversion() { + { + ClassificationFeatureImportance classificationFeatureImportance = ClassificationFeatureImportanceTests.createRandomInstance(); + LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromClassification(classificationFeatureImportance); + ClassificationFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forClassification(); + assertThat(convertedFeatureImportance, equalTo(classificationFeatureImportance)); + } + { + LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(true); + ClassificationFeatureImportance classificationFeatureImportance = legacyFeatureImportance.forClassification(); + LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromClassification( + classificationFeatureImportance); + assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance)); + } + } + + public void testRegressionConversion() { + { + RegressionFeatureImportance regressionFeatureImportance = RegressionFeatureImportanceTests.createRandomInstance(); + LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance); + RegressionFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forRegression(); + assertThat(convertedFeatureImportance, equalTo(regressionFeatureImportance)); + } + { + LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(false); + RegressionFeatureImportance regressionFeatureImportance = legacyFeatureImportance.forRegression(); + LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance); + assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance)); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportanceTests.java new file mode 100644 index 00000000000..8bb85d76cbb --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportanceTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class RegressionFeatureImportanceTests extends AbstractSerializingTestCase { + + @Override + protected RegressionFeatureImportance doParseInstance(XContentParser parser) throws IOException { + return RegressionFeatureImportance.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionFeatureImportance::new; + } + + @Override + protected RegressionFeatureImportance createTestInstance() { + return createRandomInstance(); + } + + public static RegressionFeatureImportance createRandomInstance() { + return new RegressionFeatureImportance(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 29a40248474..e6f865ab59f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -29,8 +29,8 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest public static RegressionInferenceResults createRandomResults() { return new RegressionInferenceResults(randomDouble(), RegressionConfigTests.randomRegressionConfig(), - randomBoolean() ? null : - Stream.generate(FeatureImportanceTests::randomRegression) + randomBoolean() ? Collections.emptyList() : + Stream.generate(RegressionFeatureImportanceTests::createRandomInstance) .limit(randomIntBetween(1, 10)) .collect(Collectors.toList())); } @@ -50,7 +50,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest } public void testWriteResultsWithImportance() { - List importanceList = Stream.generate(FeatureImportanceTests::randomRegression) + List importanceList = Stream.generate(RegressionFeatureImportanceTests::createRandomInstance) .limit(5) .collect(Collectors.toList()); RegressionInferenceResults result = new RegressionInferenceResults(0.3, @@ -68,7 +68,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance()))); for (int i = 0; i < 3; i++) { Map objectMap = writtenImportance.get(i); - FeatureImportance importance = importanceList.get(i); + RegressionFeatureImportance importance = importanceList.get(i); assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName())); assertThat(objectMap.get("importance"), equalTo(importance.getImportance())); assertThat(objectMap.size(), equalTo(2)); @@ -92,7 +92,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest String expected = "{\"" + resultsField + "\":1.0}"; assertEquals(expected, stringRep); - FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList()); + RegressionFeatureImportance fi = new RegressionFeatureImportance("foo", 1.0); result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi)); stringRep = Strings.toString(result); expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java index 6ecd7a8e212..728f11d215d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java @@ -16,8 +16,8 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import java.io.IOException; @@ -134,9 +134,9 @@ public class InferenceDefinitionTests extends ESTestCase { ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config); assertThat(results.valueAsString(), equalTo("second")); assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2")); - assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001)); + assertThat(results.getFeatureImportance().get(0).getTotalImportance(), closeTo(0.944, 0.001)); assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1")); - assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001)); + assertThat(results.getFeatureImportance().get(1).getTotalImportance(), closeTo(0.199, 0.001)); } public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException { @@ -155,20 +155,20 @@ public class InferenceDefinitionTests extends ESTestCase { ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config); assertThat(results.valueAsString(), equalTo("second")); - FeatureImportance featureImportance1 = results.getFeatureImportance().get(0); + ClassificationFeatureImportance featureImportance1 = results.getFeatureImportance().get(0); assertThat(featureImportance1.getFeatureName(), equalTo("col2")); - assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001)); - for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) { + assertThat(featureImportance1.getTotalImportance(), closeTo(0.944, 0.001)); + for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) { if (classImportance.getClassName().equals("second")) { assertThat(classImportance.getImportance(), closeTo(0.944, 0.001)); } else { assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001)); } } - FeatureImportance featureImportance2 = results.getFeatureImportance().get(1); + ClassificationFeatureImportance featureImportance2 = results.getFeatureImportance().get(1); assertThat(featureImportance2.getFeatureName(), equalTo("col1_male")); - assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001)); - for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) { + assertThat(featureImportance2.getTotalImportance(), closeTo(0.199, 0.001)); + for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) { if (classImportance.getClassName().equals("second")) { assertThat(classImportance.getImportance(), closeTo(0.199, 0.001)); } else { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java index 04f8aa7e28c..093dfd277ea 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java @@ -16,10 +16,11 @@ import org.elasticsearch.search.aggregations.InvalidAggregationPathException; import org.elasticsearch.search.aggregations.ParsedAggregation; import org.elasticsearch.test.InternalAggregationTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; @@ -115,7 +116,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa } else if (result instanceof RegressionInferenceResults) { RegressionInferenceResults regression = (RegressionInferenceResults) result; assertEquals(regression.value(), parsed.getValue()); - List featureImportance = regression.getFeatureImportance(); + List featureImportance = regression.getFeatureImportance(); if (featureImportance.isEmpty()) { featureImportance = null; } @@ -124,7 +125,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa ClassificationInferenceResults classification = (ClassificationInferenceResults) result; assertEquals(classification.predictedValue(), parsed.getValue()); - List featureImportance = classification.getFeatureImportance(); + List featureImportance = classification.getFeatureImportance(); if (featureImportance.isEmpty()) { featureImportance = null; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java index 7ea22f6365a..fd74103ec4b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.aggregations.ParsedAggregation; -import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -21,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf import java.io.IOException; import java.util.List; +import java.util.Map; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY; @@ -45,7 +45,7 @@ public class ParsedInference extends ParsedAggregation { @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true, - args -> new ParsedInference(args[0], (List) args[1], + args -> new ParsedInference(args[0], (List>) args[1], (List) args[2], (String) args[3], (Double) args[4], (Double) args[5])); static { @@ -65,7 +65,7 @@ public class ParsedInference extends ParsedAggregation { } return o; }, CommonFields.VALUE, ObjectParser.ValueType.VALUE); - PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> p.map(), new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE)); PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD)); @@ -82,14 +82,14 @@ public class ParsedInference extends ParsedAggregation { } private final Object value; - private final List featureImportance; + private final List> featureImportance; private final List topClasses; private final String warning; private final Double predictionProbability; private final Double predictionScore; ParsedInference(Object value, - List featureImportance, + List> featureImportance, List topClasses, String warning, Double predictionProbability, @@ -106,7 +106,7 @@ public class ParsedInference extends ParsedAggregation { return value; } - public List getFeatureImportance() { + public List> getFeatureImportance() { return featureImportance; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 5fa4ea21d45..3017e311ea2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -9,8 +9,9 @@ import org.elasticsearch.client.Client; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -136,9 +137,11 @@ public class InferenceProcessorTests extends ESTestCase { classes.add(new TopClassEntry("foo", 0.6, 0.6)); classes.add(new TopClassEntry("bar", 0.4, 0.4)); - List featureInfluence = new ArrayList<>(); - featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13)); - featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0)); + List featureInfluence = new ArrayList<>(); + featureInfluence.add(new ClassificationFeatureImportance("feature_1", + Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_a", 1.13)))); + featureInfluence.add(new ClassificationFeatureImportance("feature_2", + Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_b", -42.0)))); InternalInferModelAction.Response response = new InternalInferModelAction.Response( Collections.singletonList(new ClassificationInferenceResults(1.0, @@ -153,10 +156,12 @@ public class InferenceProcessorTests extends ESTestCase { assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model")); assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo")); - assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0)); assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2")); - assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13)); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.class_name", String.class), equalTo("class_b")); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.importance", Double.class), equalTo(-42.0)); assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1")); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.class_name", String.class), equalTo("class_a")); + assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.importance", Double.class), equalTo(1.13)); } @SuppressWarnings("unchecked") @@ -234,9 +239,9 @@ public class InferenceProcessorTests extends ESTestCase { Map ingestMetadata = new HashMap<>(); IngestDocument document = new IngestDocument(source, ingestMetadata); - List featureInfluence = new ArrayList<>(); - featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13)); - featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0)); + List featureInfluence = new ArrayList<>(); + featureInfluence.add(new RegressionFeatureImportance("feature_1", 1.13)); + featureInfluence.add(new RegressionFeatureImportance("feature_2", -42.0)); InternalInferModelAction.Response response = new InternalInferModelAction.Response( Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);