diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java index 882dc046d6d..7f981c8327c 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -40,6 +40,7 @@ public class TotalFeatureImportance implements ToXContentObject { public static final ParseField IMPORTANCE = new ParseField("importance"); public static final ParseField CLASSES = new ParseField("classes"); public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude"); + public static final ParseField BASELINE = new ParseField("baseline"); public static final ParseField MIN = new ParseField("min"); public static final ParseField MAX = new ParseField("max"); @@ -102,22 +103,25 @@ public class TotalFeatureImportance implements ToXContentObject { public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, - a -> new Importance((double)a[0], (double)a[1], (double)a[2])); + a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3])); static { PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN); PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX); + PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE); } private final double meanMagnitude; private final double min; private final double max; + private final Double baseline; - public Importance(double meanMagnitude, double min, double max) { + public Importance(double meanMagnitude, double min, double max, Double baseline) { this.meanMagnitude = meanMagnitude; this.min = min; this.max = max; + this.baseline = baseline; } @Override @@ -127,12 +131,13 @@ public class TotalFeatureImportance implements ToXContentObject { Importance that = (Importance) o; return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && Double.compare(that.min, min) == 0 && - Double.compare(that.max, max) == 0; + Double.compare(that.max, max) == 0 && + Objects.equals(that.baseline, baseline); } @Override public int hashCode() { - return Objects.hash(meanMagnitude, min, max); + return Objects.hash(meanMagnitude, min, max, baseline); } @Override @@ -141,6 +146,9 @@ public class TotalFeatureImportance implements ToXContentObject { builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); builder.field(MIN.getPreferredName(), min); builder.field(MAX.getPreferredName(), max); + if (baseline != null) { + builder.field(BASELINE.getPreferredName(), baseline); + } builder.endObject(); return builder; } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java index adbf9ab052d..5f185df6e6a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java @@ -50,7 +50,11 @@ public class TotalFeatureImportanceTests extends AbstractXContentTestCase LENIENT_PARSER = createParser(true); @@ -124,27 +125,31 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, - a -> new Importance((double)a[0], (double)a[1], (double)a[2])); + a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3])); parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN); parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE); return parser; } private final double meanMagnitude; private final double min; private final double max; + private final Double baseline; - public Importance(double meanMagnitude, double min, double max) { + public Importance(double meanMagnitude, double min, double max, Double baseline) { this.meanMagnitude = meanMagnitude; this.min = min; this.max = max; + this.baseline = baseline; } public Importance(StreamInput in) throws IOException { this.meanMagnitude = in.readDouble(); this.min = in.readDouble(); this.max = in.readDouble(); + this.baseline = in.readOptionalDouble(); } @Override @@ -154,12 +159,13 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { Importance that = (Importance) o; return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && Double.compare(that.min, min) == 0 && - Double.compare(that.max, max) == 0; + Double.compare(that.max, max) == 0 && + Objects.equals(that.baseline, baseline); } @Override public int hashCode() { - return Objects.hash(meanMagnitude, min, max); + return Objects.hash(meanMagnitude, min, max, baseline); } @Override @@ -167,6 +173,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { out.writeDouble(meanMagnitude); out.writeDouble(min); out.writeDouble(max); + out.writeOptionalDouble(baseline); } @Override @@ -179,6 +186,9 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); map.put(MIN.getPreferredName(), min); map.put(MAX.getPreferredName(), max); + if (baseline != null) { + map.put(BASELINE.getPreferredName(), baseline); + } return map; } } diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json index 00f5eb2a90f..f5fb2768a8d 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json @@ -85,6 +85,9 @@ }, "mean_magnitude": { "type": "double" + }, + "baseline": { + "type": "double" } } }, @@ -105,6 +108,9 @@ }, "mean_magnitude": { "type": "double" + }, + "baseline": { + "type": "double" } } }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java index fa68e71e8cc..ea5ccde3b9c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java @@ -41,7 +41,11 @@ public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCas } private static TotalFeatureImportance.Importance randomImportance() { - return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble()); + return new TotalFeatureImportance.Importance( + randomDouble(), + randomDouble(), + randomDouble(), + randomBoolean() ? null : randomDouble()); } @Before