diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java index d6d0bd4b04f..c83b90fcc15 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java @@ -27,8 +27,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Collections; -import java.util.HashMap; -import java.util.Map; +import java.util.List; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; @@ -38,36 +37,37 @@ public class FeatureImportance implements ToXContentObject { public static final String IMPORTANCE = "importance"; public static final String FEATURE_NAME = "feature_name"; - public static final String CLASS_IMPORTANCE = "class_importance"; + public static final String CLASSES = "classes"; @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("feature_importance", true, - a -> new FeatureImportance((String) a[0], (Double) a[1], (Map) a[2]) + a -> new FeatureImportance((String) a[0], (Double) a[1], (List) a[2]) ); static { PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME)); PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); - PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue), - new ParseField(FeatureImportance.CLASS_IMPORTANCE)); + PARSER.declareObjectArray(optionalConstructorArg(), + (p, c) -> ClassImportance.fromXContent(p), + new ParseField(FeatureImportance.CLASSES)); } public static FeatureImportance fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private final Map classImportance; + private final List classImportance; private final double importance; private final String featureName; - public FeatureImportance(String featureName, double importance, Map classImportance) { + public FeatureImportance(String featureName, double importance, List classImportance) { this.featureName = Objects.requireNonNull(featureName); this.importance = importance; - this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance); + this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance); } - public Map getClassImportance() { + public List getClassImportance() { return classImportance; } @@ -85,11 +85,7 @@ public class FeatureImportance implements ToXContentObject { builder.field(FEATURE_NAME, featureName); builder.field(IMPORTANCE, importance); if (classImportance != null && classImportance.isEmpty() == false) { - builder.startObject(CLASS_IMPORTANCE); - for (Map.Entry entry : classImportance.entrySet()) { - builder.field(entry.getKey(), entry.getValue()); - } - builder.endObject(); + builder.field(CLASSES, classImportance); } builder.endObject(); return builder; @@ -109,4 +105,63 @@ public class FeatureImportance implements ToXContentObject { public int hashCode() { return Objects.hash(featureName, importance, classImportance); } + + public static class ClassImportance implements ToXContentObject { + + static final String CLASS_NAME = "class_name"; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("feature_importance_class_importance", + true, + a -> new ClassImportance((String) a[0], (Double) a[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME)); + PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE)); + } + + public static ClassImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final String className; + private final double importance; + + public ClassImportance(String className, double importance) { + this.className = className; + this.importance = importance; + } + + public String getClassName() { + return className; + } + + public double getImportance() { + return importance; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME, className); + builder.field(IMPORTANCE, importance); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassImportance that = (ClassImportance) o; + return Double.compare(that.importance, importance) == 0 && + Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(className, importance); + } + } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java index 9e6c8492e74..0da86667e1d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java @@ -23,8 +23,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; -import java.util.function.Function; -import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -38,7 +36,8 @@ public class FeatureImportanceTests extends AbstractXContentTestCase randomAlphaOfLength(10)) .limit(randomLongBetween(2, 10)) - .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false)))); + .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false))) + .collect(Collectors.toList())); } @@ -52,8 +51,4 @@ public class FeatureImportanceTests extends AbstractXContentTestCase getRandomFieldsExcludeFilter() { - return field -> field.equals(FeatureImportance.CLASS_IMPORTANCE); - } }