diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 46c57b4a40c..758a1700878 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -779,9 +779,9 @@ final class MLRequestConverters { params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION, Boolean.toString(getTrainedModelsRequest.getDecompressDefinition())); } - if (getTrainedModelsRequest.getIncludeDefinition() != null) { - params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION, - Boolean.toString(getTrainedModelsRequest.getIncludeDefinition())); + if (getTrainedModelsRequest.getIncludes().isEmpty() == false) { + params.putParam(GetTrainedModelsRequest.INCLUDE, + Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes())); } if (getTrainedModelsRequest.getTags() != null) { params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags())); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java index ca0284de84d..29fb67b3e75 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java @@ -26,21 +26,26 @@ import org.elasticsearch.client.ml.inference.TrainedModelConfig; import org.elasticsearch.common.Nullable; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; public class GetTrainedModelsRequest implements Validatable { + private static final String DEFINITION = "definition"; + private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance"; public static final String ALLOW_NO_MATCH = "allow_no_match"; - public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; public static final String FOR_EXPORT = "for_export"; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; public static final String TAGS = "tags"; + public static final String INCLUDE = "include"; private final List ids; private Boolean allowNoMatch; - private Boolean includeDefinition; + private Set includes = new HashSet<>(); private Boolean decompressDefinition; private Boolean forExport; private PageParams pageParams; @@ -86,19 +91,32 @@ public class GetTrainedModelsRequest implements Validatable { return this; } - public Boolean getIncludeDefinition() { - return includeDefinition; + public Set getIncludes() { + return Collections.unmodifiableSet(includes); + } + + public GetTrainedModelsRequest includeDefinition() { + this.includes.add(DEFINITION); + return this; + } + + public GetTrainedModelsRequest includeTotalFeatureImportance() { + this.includes.add(TOTAL_FEATURE_IMPORTANCE); + return this; } /** * Whether to include the full model definition. * * The full model definition can be very large. - * + * @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()} * @param includeDefinition If {@code true}, the definition is included. */ + @Deprecated public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) { - this.includeDefinition = includeDefinition; + if (includeDefinition != null && includeDefinition) { + return this.includeDefinition(); + } return this; } @@ -173,13 +191,13 @@ public class GetTrainedModelsRequest implements Validatable { return Objects.equals(ids, other.ids) && Objects.equals(allowNoMatch, other.allowNoMatch) && Objects.equals(decompressDefinition, other.decompressDefinition) - && Objects.equals(includeDefinition, other.includeDefinition) + && Objects.equals(includes, other.includes) && Objects.equals(forExport, other.forExport) && Objects.equals(pageParams, other.pageParams); } @Override public int hashCode() { - return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport); + return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java new file mode 100644 index 00000000000..882dc046d6d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -0,0 +1,208 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class TotalFeatureImportance implements ToXContentObject { + + private static final String NAME = "total_feature_importance"; + public static final ParseField FEATURE_NAME = new ParseField("feature_name"); + public static final ParseField IMPORTANCE = new ParseField("importance"); + public static final ParseField CLASSES = new ParseField("classes"); + public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude"); + public static final ParseField MIN = new ParseField("min"); + public static final ParseField MAX = new ParseField("max"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List)a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); + PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE); + PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES); + } + + public static TotalFeatureImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public final String featureName; + public final Importance importance; + public final List classImportances; + + TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List classImportances) { + this.featureName = featureName; + this.importance = importance; + this.classImportances = classImportances == null ? Collections.emptyList() : classImportances; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FEATURE_NAME.getPreferredName(), featureName); + if (importance != null) { + builder.field(IMPORTANCE.getPreferredName(), importance); + } + if (classImportances.isEmpty() == false) { + builder.field(CLASSES.getPreferredName(), classImportances); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TotalFeatureImportance that = (TotalFeatureImportance) o; + return Objects.equals(that.importance, importance) + && Objects.equals(featureName, that.featureName) + && Objects.equals(classImportances, that.classImportances); + } + + @Override + public int hashCode() { + return Objects.hash(featureName, importance, classImportances); + } + + public static class Importance implements ToXContentObject { + private static final String NAME = "importance"; + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new Importance((double)a[0], (double)a[1], (double)a[2])); + + static { + PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); + PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN); + PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX); + } + + private final double meanMagnitude; + private final double min; + private final double max; + + public Importance(double meanMagnitude, double min, double max) { + this.meanMagnitude = meanMagnitude; + this.min = min; + this.max = max; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Importance that = (Importance) o; + return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && + Double.compare(that.min, min) == 0 && + Double.compare(that.max, max) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(meanMagnitude, min, max); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); + builder.field(MIN.getPreferredName(), min); + builder.field(MAX.getPreferredName(), max); + builder.endObject(); + return builder; + } + } + + public static class ClassImportance implements ToXContentObject { + private static final String NAME = "total_class_importance"; + + public static final ParseField CLASS_NAME = new ParseField("class_name"); + public static final ParseField IMPORTANCE = new ParseField("importance"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new ClassImportance(a[0], (Importance)a[1])); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return p.text(); + } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) { + return p.numberValue(); + } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) { + return p.booleanValue(); + } + throw new XContentParseException("Unsupported token [" + p.currentToken() + "]"); + }, CLASS_NAME, ObjectParser.ValueType.VALUE); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE); + } + + public static ClassImportance fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public final Object className; + public final Importance importance; + + ClassImportance(Object className, Importance importance) { + this.className = className; + this.importance = importance; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(IMPORTANCE.getPreferredName(), importance); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassImportance that = (ClassImportance) o; + return Objects.equals(that.importance, importance) && Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(className, importance); + } + + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 5a83ab9de88..75958f8e433 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -894,7 +894,7 @@ public class MLRequestConvertersTests extends ESTestCase { GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3) .setAllowNoMatch(false) .setDecompressDefinition(true) - .setIncludeDefinition(false) + .includeDefinition() .setTags("tag1", "tag2") .setPageParams(new PageParams(100, 300)); @@ -908,7 +908,7 @@ public class MLRequestConvertersTests extends ESTestCase { hasEntry("allow_no_match", "false"), hasEntry("decompress_definition", "true"), hasEntry("tags", "tag1,tag2"), - hasEntry("include_model_definition", "false") + hasEntry("include", "definition") )); assertNull(request.getEntity()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 577e62f23e3..36608760079 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -2257,7 +2257,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { { GetTrainedModelsResponse getTrainedModelsResponse = execute( - new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true), + new GetTrainedModelsRequest(modelIdPrefix + 0) + .setDecompressDefinition(true) + .includeDefinition() + .includeTotalFeatureImportance(), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); @@ -2268,7 +2271,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0)); getTrainedModelsResponse = execute( - new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true), + new GetTrainedModelsRequest(modelIdPrefix + 0) + .setDecompressDefinition(false) + .includeTotalFeatureImportance() + .includeDefinition(), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); @@ -2279,7 +2285,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0)); getTrainedModelsResponse = execute( - new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false), + new GetTrainedModelsRequest(modelIdPrefix + 0) + .setDecompressDefinition(false), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); assertThat(getTrainedModelsResponse.getCount(), equalTo(1L)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 87e77f7074e..ccfa24c7fa1 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3694,11 +3694,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { // tag::get-trained-models-request GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1> .setPageParams(new PageParams(0, 1)) // <2> - .setIncludeDefinition(false) // <3> - .setDecompressDefinition(false) // <4> - .setAllowNoMatch(true) // <5> - .setTags("regression") // <6> - .setForExport(false); // <7> + .includeDefinition() // <3> + .includeTotalFeatureImportance() // <4> + .setDecompressDefinition(false) // <5> + .setAllowNoMatch(true) // <6> + .setTags("regression") // <7> + .setForExport(false); // <8> // end::get-trained-models-request request.setTags((List)null); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java new file mode 100644 index 00000000000..eef5c3bae21 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TotalFeatureImportanceTests extends AbstractXContentTestCase { + + + public static TotalFeatureImportance randomInstance() { + return new TotalFeatureImportance( + randomAlphaOfLength(10), + randomBoolean() ? null : randomImportance(), + randomBoolean() ? + null : + Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance())) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList()) + ); + } + + private static TotalFeatureImportance.Importance randomImportance() { + return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble()); + } + + @Override + protected TotalFeatureImportance createTestInstance() { + return randomInstance(); + } + + @Override + protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException { + return TotalFeatureImportance.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + +} diff --git a/docs/java-rest/high-level/ml/get-trained-models.asciidoc b/docs/java-rest/high-level/ml/get-trained-models.asciidoc index ffaea526f01..275b4c54292 100644 --- a/docs/java-rest/high-level/ml/get-trained-models.asciidoc +++ b/docs/java-rest/high-level/ml/get-trained-models.asciidoc @@ -22,26 +22,28 @@ IDs, or the special wildcard `_all` to get all trained models. -------------------------------------------------- include-tagged::{doc-tests-file}[{api}-request] -------------------------------------------------- -<1> Constructing a new GET request referencing an existing Trained Model +<1> Constructing a new GET request referencing an existing trained model <2> Set the paging parameters <3> Indicate if the complete model definition should be included -<4> Should the definition be fully decompressed on GET -<5> Allow empty response if no Trained Models match the provided ID patterns. - If false, an error will be thrown if no Trained Models match the +<4> Indicate if the total feature importance for the features used in training + should be included in the model `metadata` field. +<5> Should the definition be fully decompressed on GET +<6> Allow empty response if no trained models match the provided ID patterns. + If false, an error will be thrown if no trained models match the ID patterns. -<6> An optional list of tags used to narrow the model search. A Trained Model +<7> An optional list of tags used to narrow the model search. A trained model can have many tags or none. The trained models in the response will contain all the provided tags. -<7> Optional boolean value indicating if certain fields should be removed on - retrieval. This is useful for getting the trained model in a format that - can then be put into another cluster. +<8> Optional boolean value for requesting the trained model in a format that can + then be put into another cluster. Certain fields that can only be set when + the model is imported are removed. include::../execution.asciidoc[] [id="{upid}-{api}-response"] ==== Response -The returned +{response}+ contains the requested Trained Model. +The returned +{response}+ contains the requested trained model. ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc index 52b5b7372ee..1c69753b6fe 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc @@ -29,18 +29,19 @@ experimental[] [[ml-get-inference-prereq]] == {api-prereq-title} -Required privileges which should be added to a custom role: +If the {es} {security-features} are enabled, you must have the following +privileges: * cluster: `monitor_ml` - -For more information, see <> and + +For more information, see <> and {ml-docs-setup-privileges}. [[ml-get-inference-desc]] == {api-description-title} -You can get information for multiple trained models in a single API request by +You can get information for multiple trained models in a single API request by using a comma-separated list of model IDs or a wildcard expression. @@ -48,7 +49,7 @@ using a comma-separated list of model IDs or a wildcard expression. == {api-path-parms-title} ``:: -(Optional, string) +(Optional, string) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] @@ -56,12 +57,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] == {api-query-parms-title} `allow_no_match`:: -(Optional, boolean) +(Optional, boolean) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models] `decompress_definition`:: (Optional, boolean) -Specifies whether the included model definition should be returned as a JSON map +Specifies whether the included model definition should be returned as a JSON map (`true`) or in a custom compressed format (`false`). Defaults to `true`. `for_export`:: @@ -71,17 +72,21 @@ retrieval. This allows the model to be in an acceptable format to be retrieved and then added to another cluster. Default is false. `from`:: -(Optional, integer) +(Optional, integer) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models] -`include_model_definition`:: -(Optional, boolean) -Specifies whether the model definition is returned in the response. Defaults to -`false`. When `true`, only a single model must match the ID patterns provided. -Otherwise, a bad request is returned. +`include`:: +(Optional, string) +A comma delimited string of optional fields to include in the response body. +Valid options are: + - `definition`: Includes the model definition + - `total_feature_importance`: Includes the total feature importance for the + training data set. This field is available in the `metadata` field in the + response body. +Default is empty, indicating including no optional fields. `size`:: -(Optional, integer) +(Optional, integer) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models] `tags`:: @@ -94,7 +99,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tags] `trained_model_configs`:: (array) -An array of trained model resources, which are sorted by the `model_id` value in +An array of trained model resources, which are sorted by the `model_id` value in ascending order. + .Properties of trained model resources @@ -132,8 +137,86 @@ The license level of the trained model. `metadata`::: (object) -An object containing metadata about the trained model. For example, models +An object containing metadata about the trained model. For example, models created by {dfanalytics} contain `analysis_config` and `input` objects. +.Properties of metadata +[%collapsible%open] +===== +`total_feature_importance`::: +(array) +An array of the total feature importance for each feature used from +the training data set. This array of objects is returned if {dfanalytics} trained +the model and the request includes `total_feature_importance` in the `include` +request parameter. ++ +.Properties of total feature importance +[%collapsible%open] +====== + +`feature_name`::: +(string) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name] + +`importance`::: +(object) +A collection of feature importance statistics related to the training data set for this particular feature. ++ +.Properties of feature importance +[%collapsible%open] +======= +`mean_magnitude`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] + +`max`::: +(int) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] + +`min`::: +(int) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] + +======= + +`classes`::: +(array) +If the trained model is a classification model, feature importance statistics are gathered +per target class value. ++ +.Properties of class feature importance +[%collapsible%open] + +======= + +`class_name`::: +(string) +The target class value. Could be a string, boolean, or number. + +`importance`::: +(object) +A collection of feature importance statistics related to the training data set for this particular feature. ++ +.Properties of feature importance +[%collapsible%open] +======== +`mean_magnitude`::: +(double) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude] + +`max`::: +(int) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max] + +`min`::: +(int) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min] + +======== + +======= + +====== +===== `model_id`::: (string) @@ -152,13 +235,13 @@ The {es} version number in which the trained model was created. == {api-response-codes-title} `400`:: - If `include_model_definition` is `true`, this code indicates that more than + If `include_model_definition` is `true`, this code indicates that more than one models match the ID pattern. `404` (Missing resources):: If `allow_no_match` is `false`, this code indicates that there are no resources that match the request or only partial matches for the request. - + [[ml-get-inference-example]] == {api-examples-title} diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 730a82d7e45..b9172c497e8 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -780,6 +780,23 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that used to train the model, which defaults to `_prediction`. end::inference-config-results-field-processor[] +tag::inference-metadata-feature-importance-feature-name[] +The training feature name for which this importance was calculated. +end::inference-metadata-feature-importance-feature-name[] +tag::inference-metadata-feature-importance-magnitude[] +The average magnitude of this feature across all the training data. +This value is the average of the absolute values of the importance +for this feature. +end::inference-metadata-feature-importance-magnitude[] +tag::inference-metadata-feature-importance-max[] +The maximum importance value across all the training data for this +feature. +end::inference-metadata-feature-importance-max[] +tag::inference-metadata-feature-importance-min[] +The minimum importance value across all the training data for this +feature. +end::inference-metadata-feature-importance-min[] + tag::influencers[] A comma separated list of influencer field names. Typically these can be the by, over, or partition fields that are used in the detector configuration. You might diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java index 84330f7924a..b15ceb19d87 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -10,15 +10,19 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; public class GetTrainedModelsAction extends ActionType { @@ -32,23 +36,60 @@ public class GetTrainedModelsAction extends ActionType KNOWN_INCLUDES; + static { + HashSet includes = new HashSet<>(2, 1.0f); + includes.add(DEFINITION); + includes.add(TOTAL_FEATURE_IMPORTANCE); + KNOWN_INCLUDES = Collections.unmodifiableSet(includes); + } + public static final ParseField INCLUDE = new ParseField("include"); + public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); public static final ParseField TAGS = new ParseField("tags"); - private final boolean includeModelDefinition; + private final Set includes; private final List tags; + @Deprecated public Request(String id, boolean includeModelDefinition, List tags) { setResourceId(id); setAllowNoResources(true); - this.includeModelDefinition = includeModelDefinition; this.tags = tags == null ? Collections.emptyList() : tags; + if (includeModelDefinition) { + this.includes = new HashSet<>(Collections.singletonList(DEFINITION)); + } else { + this.includes = Collections.emptySet(); + } + } + + public Request(String id, List tags, Set includes) { + setResourceId(id); + setAllowNoResources(true); + this.tags = tags == null ? Collections.emptyList() : tags; + this.includes = includes == null ? Collections.emptySet() : includes; + Set unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES); + if (unknownIncludes.isEmpty() == false) { + throw ExceptionsHelper.badRequestException( + "unknown [include] parameters {}. Valid options are {}", + unknownIncludes, + KNOWN_INCLUDES); + } } public Request(StreamInput in) throws IOException { super(in); - this.includeModelDefinition = in.readBoolean(); + if (in.getVersion().onOrAfter(Version.V_7_10_0)) { + this.includes = in.readSet(StreamInput::readString); + } else { + Set includes = new HashSet<>(); + if (in.readBoolean()) { + includes.add(DEFINITION); + } + this.includes = includes; + } if (in.getVersion().onOrAfter(Version.V_7_7_0)) { this.tags = in.readStringList(); } else { @@ -62,7 +103,11 @@ public class GetTrainedModelsAction extends ActionType getTags() { @@ -72,7 +117,11 @@ public class GetTrainedModelsAction extends ActionType RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE); private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; @@ -419,7 +423,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition); this.description = config.getDescription(); this.tags = config.getTags(); - this.metadata = config.getMetadata(); + this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata()); this.input = config.getInput(); this.estimatedOperations = config.estimatedOperations; this.estimatedHeapMemory = config.estimatedHeapMemory; @@ -471,6 +475,18 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return this; } + public Builder setFeatureImportance(List totalFeatureImportance) { + if (totalFeatureImportance == null) { + return this; + } + if (this.metadata == null) { + this.metadata = new HashMap<>(); + } + this.metadata.put(TOTAL_FEATURE_IMPORTANCE, + totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList())); + return this; + } + public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) { if (definition == null) { return this; @@ -627,6 +643,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { ESTIMATED_OPERATIONS.getPreferredName(), validationException); validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException); + if (metadata != null) { + validationException = checkIllegalSetting( + metadata.get(TOTAL_FEATURE_IMPORTANCE), + METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE, + validationException); + } } if (validationException != null) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java index 9f2df2b7512..8676af6ff5c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -20,8 +20,11 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; public class TotalFeatureImportance implements ToXContentObject, Writeable { @@ -81,16 +84,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(FEATURE_NAME.getPreferredName(), featureName); - if (importance != null) { - builder.field(IMPORTANCE.getPreferredName(), importance); - } - if (classImportances.isEmpty() == false) { - builder.field(CLASSES.getPreferredName(), classImportances); - } - builder.endObject(); - return builder; + return builder.map(asMap()); } @Override @@ -103,6 +97,18 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { && Objects.equals(classImportances, that.classImportances); } + public Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(FEATURE_NAME.getPreferredName(), featureName); + if (importance != null) { + map.put(IMPORTANCE.getPreferredName(), importance.asMap()); + } + if (classImportances.isEmpty() == false) { + map.put(CLASSES.getPreferredName(), classImportances.stream().map(ClassImportance::asMap).collect(Collectors.toList())); + } + return map; + } + @Override public int hashCode() { return Objects.hash(featureName, importance, classImportances); @@ -165,12 +171,15 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); - builder.field(MIN.getPreferredName(), min); - builder.field(MAX.getPreferredName(), max); - builder.endObject(); - return builder; + return builder.map(asMap()); + } + + private Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); + map.put(MIN.getPreferredName(), min); + map.put(MAX.getPreferredName(), max); + return map; } } @@ -229,11 +238,14 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(CLASS_NAME.getPreferredName(), className); - builder.field(IMPORTANCE.getPreferredName(), importance); - builder.endObject(); - return builder; + return builder.map(asMap()); + } + + private Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(CLASS_NAME.getPreferredName(), className); + map.put(IMPORTANCE.getPreferredName(), importance.asMap()); + return map; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java index dc3e8fc54d9..dd2662cf400 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java @@ -53,6 +53,10 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable { return NAME + "-" + modelId; } + public static String modelId(String docId) { + return docId.substring(NAME.length() + 1); + } + private final List totalFeatureImportances; private final String modelId; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 6a7965a01b2..4d18a8f1006 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -103,7 +103,7 @@ public final class Messages { public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; - public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]"; + public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}"; public static final String INFERENCE_CANNOT_DELETE_MODEL = "Unable to delete model [{0}]"; public static final String MODEL_DEFINITION_TRUNCATED = diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java index 7955117e117..a761a00c1dc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ -5,19 +5,28 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request; -public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase { +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTestCase { @Override protected Request createTestInstance() { Request request = new Request(randomAlphaOfLength(20), - randomBoolean(), randomBoolean() ? null : - randomList(10, () -> randomAlphaOfLength(10))); + randomList(10, () -> randomAlphaOfLength(10)), + randomBoolean() ? null : + Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE)) + .limit(4) + .collect(Collectors.toSet())); request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); return request; } @@ -26,4 +35,22 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas protected Writeable.Reader instanceReader() { return Request::new; } + + @Override + protected Request mutateInstanceForVersion(Request instance, Version version) { + if (version.before(Version.V_7_10_0)) { + Set includes = new HashSet<>(); + if (instance.isIncludeModelDefinition()) { + includes.add(Request.DEFINITION); + } + Request request = new Request( + instance.getResourceId(), + version.before(Version.V_7_7_0) ? null : instance.getTags(), + includes); + request.setPageParams(instance.getPageParams()); + request.setAllowNoResources(instance.isAllowNoResources()); + return request; + } + return instance; + } } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index b9549842333..defde90095d 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -42,11 +42,13 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.startsWith; public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { @@ -95,19 +97,21 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture); Tuple> ids = getIdsFuture.actionGet(); assertThat(ids.v1(), equalTo(1L)); + String inferenceModelId = ids.v2().iterator().next(); PlainActionFuture getTrainedModelFuture = new PlainActionFuture<>(); - trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture); + trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture); TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet(); assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition)); assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations())); assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); + assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance")); - PlainActionFuture getTrainedMetadataFuture = new PlainActionFuture<>(); - trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture); + PlainActionFuture> getTrainedMetadataFuture = new PlainActionFuture<>(); + trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture); - TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet(); + TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId); assertThat(storedMetadata.getModelId(), startsWith(modelId)); assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances())); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index b920c5686dd..ca4a5723b88 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -90,7 +90,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { assertThat(exceptionHolder.get(), is(nullValue())); AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); getConfigHolder.get().ensureParsedDefinition(xContentRegistry()); assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), equalTo(config)); @@ -121,7 +124,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { assertThat(exceptionHolder.get(), is(nullValue())); AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder); + blockingCall(listener -> + trainedModelProvider.getTrainedModel(modelId, false, false, listener), + getConfigHolder, + exceptionHolder); getConfigHolder.get().ensureParsedDefinition(xContentRegistry()); assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition)); @@ -132,7 +138,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { String modelId = "test-get-missing-trained-model-config"; AtomicReference getConfigHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); @@ -154,7 +163,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { .actionGet(); AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); @@ -193,7 +205,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { } AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(getConfigHolder.get(), is(nullValue())); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); @@ -238,7 +253,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { } } AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + blockingCall( + listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener), + getConfigHolder, + exceptionHolder); assertThat(getConfigHolder.get(), is(nullValue())); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index 1ffc13b8b11..ba3edf91f91 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -57,15 +57,25 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), - listener::onFailure - )); + provider.getTrainedModel( + totalAndIds.v2().iterator().next(), + true, + request.isIncludeTotalFeatureImportance(), + ActionListener.wrap( + config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), + listener::onFailure + ) + ); } else { - provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap( - configs -> listener.onResponse(responseBuilder.setModels(configs).build()), - listener::onFailure - )); + provider.getTrainedModels( + totalAndIds.v2(), + request.isAllowNoResources(), + request.isIncludeTotalFeatureImportance(), + ActionListener.wrap( + configs -> listener.onResponse(responseBuilder.setModels(configs).build()), + listener::onFailure + ) + ); } }, listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index c94c668a87b..2483a2cffaa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -82,7 +82,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction { responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel())); if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 1c48ce9f141..838f06aacb7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -270,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener { } private void loadModel(String modelId, Consumer consumer) { - provider.getTrainedModel(modelId, false, ActionListener.wrap( + provider.getTrainedModel(modelId, false, false, ActionListener.wrap( trainedModelConfig -> { trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); provider.getTrainedModelForInference(modelId, ActionListener.wrap( @@ -306,7 +306,7 @@ public class ModelLoadingService implements ClusterStateListener { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); - provider.getTrainedModel(modelId, false, ActionListener.wrap( + provider.getTrainedModel(modelId, false, false, ActionListener.wrap( trainedModelConfig -> { // Verify we can pull the model into memory without causing OOM trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); @@ -434,7 +434,7 @@ public class ModelLoadingService implements ClusterStateListener { logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]", notification.getValue().model.getModelId())); - + // If the model is no longer referenced, flush the stats to persist as soon as possible notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false); } finally { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 0e897e210d7..1ea70da656d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -89,9 +89,11 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.TreeSet; import java.util.stream.Collectors; @@ -235,14 +237,14 @@ public class TrainedModelProvider { )); } - public void getTrainedModelMetadata(String modelId, ActionListener listener) { + public void getTrainedModelMetadata(Collection modelIds, ActionListener> listener) { SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders .boolQuery() - .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) + .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds)) .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelMetadata.NAME)))) - .setSize(1) + .setSize(10_000) // First find the latest index .addSort("_index", SortOrder.DESC) .request(); @@ -250,18 +252,20 @@ public class TrainedModelProvider { searchResponse -> { if (searchResponse.getHits().getHits().length == 0) { listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); + Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds))); return; } - List metadataList = handleHits(searchResponse.getHits().getHits(), - modelId, - this::parseMetadataLenientlyFromSource); - listener.onResponse(metadataList.get(0)); + HashMap map = new HashMap<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId())); + map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId)); + } + listener.onResponse(map); }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); + Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds))); return; } listener.onFailure(e); @@ -371,7 +375,7 @@ public class TrainedModelProvider { // TODO Change this when we get more than just langIdent stored if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { - TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry); + TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry); assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork; listener.onResponse( InferenceDefinition.builder() @@ -434,18 +438,50 @@ public class TrainedModelProvider { )); } - public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener listener) { + public void getTrainedModel(final String modelId, + final boolean includeDefinition, + final boolean includeTotalFeatureImportance, + final ActionListener finalListener) { if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { - listener.onResponse(loadModelFromResource(modelId, includeDefinition == false)); + finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build()); return; } catch (ElasticsearchException ex) { - listener.onFailure(ex); + finalListener.onFailure(ex); return; } } + ActionListener getTrainedModelListener = ActionListener.wrap( + modelBuilder -> { + if (includeTotalFeatureImportance == false) { + finalListener.onResponse(modelBuilder.build()); + return; + } + this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap( + metadata -> { + TrainedModelMetadata modelMetadata = metadata.get(modelId); + if (modelMetadata != null) { + modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances()); + } + finalListener.onResponse(modelBuilder.build()); + }, + failure -> { + // total feature importance is not necessary for a model to be valid + // we shouldn't fail if it is not found + if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { + finalListener.onResponse(modelBuilder.build()); + return; + } + finalListener.onFailure(failure); + } + )); + + }, + finalListener::onFailure + ); + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders .idsQuery() .addIds(modelId)); @@ -483,11 +519,11 @@ public class TrainedModelProvider { try { builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource); } catch (ResourceNotFoundException ex) { - listener.onFailure(new ResourceNotFoundException( + getTrainedModelListener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return; } catch (Exception ex) { - listener.onFailure(ex); + getTrainedModelListener.onFailure(ex); return; } @@ -500,22 +536,22 @@ public class TrainedModelProvider { String compressedString = getDefinitionFromDocs(docs, modelId); builder.setDefinitionFromString(compressedString); } catch (ElasticsearchException elasticsearchException) { - listener.onFailure(elasticsearchException); + getTrainedModelListener.onFailure(elasticsearchException); return; } } catch (ResourceNotFoundException ex) { - listener.onFailure(new ResourceNotFoundException( + getTrainedModelListener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); return; } catch (Exception ex) { - listener.onFailure(ex); + getTrainedModelListener.onFailure(ex); return; } } - listener.onResponse(builder.build()); + getTrainedModelListener.onResponse(builder); }, - listener::onFailure + getTrainedModelListener::onFailure ); executeAsyncWithOrigin(client, @@ -532,7 +568,10 @@ public class TrainedModelProvider { * This does no expansion on the ids. * It assumes that there are fewer than 10k. */ - public void getTrainedModels(Set modelIds, boolean allowNoResources, final ActionListener> listener) { + public void getTrainedModels(Set modelIds, + boolean allowNoResources, + boolean includeTotalFeatureImportance, + final ActionListener> finalListener) { QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) @@ -541,23 +580,63 @@ public class TrainedModelProvider { .setQuery(queryBuilder) .setSize(modelIds.size()) .request(); - List configs = new ArrayList<>(modelIds.size()); + List configs = new ArrayList<>(modelIds.size()); Set modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE); Set modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds); for(String modelId : modelsAsResource) { try { configs.add(loadModelFromResource(modelId, true)); } catch (ElasticsearchException ex) { - listener.onFailure(ex); + finalListener.onFailure(ex); return; } } if (modelsInIndex.isEmpty()) { - configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); - listener.onResponse(configs); + finalListener.onResponse(configs.stream() + .map(TrainedModelConfig.Builder::build) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())); return; } + ActionListener> getTrainedModelListener = ActionListener.wrap( + modelBuilders -> { + if (includeTotalFeatureImportance == false) { + finalListener.onResponse(modelBuilders.stream() + .map(TrainedModelConfig.Builder::build) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())); + return; + } + this.getTrainedModelMetadata(modelIds, ActionListener.wrap( + metadata -> + finalListener.onResponse(modelBuilders.stream() + .map(builder -> { + TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId()); + if (modelMetadata != null) { + builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances()); + } + return builder.build(); + }) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())), + failure -> { + // total feature importance is not necessary for a model to be valid + // we shouldn't fail if it is not found + if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { + finalListener.onResponse(modelBuilders.stream() + .map(TrainedModelConfig.Builder::build) + .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) + .collect(Collectors.toList())); + return; + } + finalListener.onFailure(failure); + } + )); + }, + finalListener::onFailure + ); + ActionListener configSearchHandler = ActionListener.wrap( searchResponse -> { Set observedIds = new HashSet<>( @@ -568,12 +647,12 @@ public class TrainedModelProvider { try { if (observedIds.contains(searchHit.getId()) == false) { configs.add( - parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build() + parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()) ); observedIds.add(searchHit.getId()); } } catch (IOException ex) { - listener.onFailure( + getTrainedModelListener.onFailure( ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId())); return; } @@ -583,14 +662,13 @@ public class TrainedModelProvider { // Otherwise, treat it as if it was never expanded to begin with. Set missingConfigs = Sets.difference(modelIds, observedIds); if (missingConfigs.isEmpty() == false && allowNoResources == false) { - listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); + getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); return; } // Ensure sorted even with the injection of locally resourced models - configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); - listener.onResponse(configs); + getTrainedModelListener.onResponse(configs); }, - listener::onFailure + getTrainedModelListener::onFailure ); executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler); @@ -639,7 +717,7 @@ public class TrainedModelProvider { foundResourceIds = new HashSet<>(); for(String resourceId : matchedResourceIds) { // Does the model as a resource have all the tags? - if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { + if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) { foundResourceIds.add(resourceId); } } @@ -833,7 +911,7 @@ public class TrainedModelProvider { return QueryBuilders.constantScoreQuery(boolQueryBuilder); } - TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) { + TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) { URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT); if (resource == null) { logger.error("[{}] presumed stored as a resource but not found", modelId); @@ -848,7 +926,7 @@ public class TrainedModelProvider { if (nullOutDefinition) { builder.clearDefinition(); } - return builder.build(); + return builder; } catch (IOException ioEx) { logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx); throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 7cf6ae7b914..5de16dfb319 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -56,12 +57,17 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { if (Strings.isNullOrEmpty(modelId)) { modelId = Metadata.ALL; } - boolean includeModelDefinition = restRequest.paramAsBoolean( - GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(), - false - ); List tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY)); - GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags); + Set includes = new HashSet<>( + asList( + restRequest.paramAsStringArray( + GetTrainedModelsAction.Request.INCLUDE.getPreferredName(), + Strings.EMPTY_ARRAY))); + final GetTrainedModelsAction.Request request = restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION) ? + new GetTrainedModelsAction.Request(modelId, + restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false), + tags) : + new GetTrainedModelsAction.Request(modelId, tags, includes); if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 8d8f54120cf..ee6a9206c9a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -437,9 +437,9 @@ public class ModelLoadingServiceTests extends ESTestCase { // the loading occurred or which models are currently in the cache due to evictions. // Verify that we have at least loaded all three assertBusy(() -> { - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any()); - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any()); - verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any()); }); assertBusy(() -> { assertThat(circuitBreaker.getUsed(), equalTo(10L)); @@ -553,10 +553,10 @@ public class ModelLoadingServiceTests extends ESTestCase { }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3]; listener.onResponse(trainedModelConfig); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any()); } @SuppressWarnings("unchecked") @@ -564,20 +564,20 @@ public class ModelLoadingServiceTests extends ESTestCase { if (randomBoolean()) { doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3]; listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any()); } else { TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3]; listener.onResponse(trainedModelConfig); return null; - }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any()); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index aee4c43f227..037b9ccc93e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -57,14 +57,14 @@ public class TrainedModelProviderTests extends ESTestCase { TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) { PlainActionFuture future = new PlainActionFuture<>(); - trainedModelProvider.getTrainedModel(modelId, true, future); + trainedModelProvider.getTrainedModel(modelId, true, false, future); TrainedModelConfig configWithDefinition = future.actionGet(); assertThat(configWithDefinition.getModelId(), equalTo(modelId)); assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue()))); PlainActionFuture futureNoDefinition = new PlainActionFuture<>(); - trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition); + trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition); TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet(); assertThat(configWithoutDefinition.getModelId(), equalTo(modelId)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java index bc6061b3d92..8eaa8f9f7c4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java @@ -33,7 +33,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase { TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); PlainActionFuture future = new PlainActionFuture<>(); // Should be OK as we don't make any client calls - trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future); + trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future); TrainedModelConfig config = future.actionGet(); config.ensureParsedDefinition(xContentRegistry()); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json index 168d233c8e3..a30cd14a752 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -34,11 +34,10 @@ "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", "default":true }, - "include_model_definition":{ - "type":"boolean", + "include":{ + "type":"string", "required":false, - "description":"Should the full model definition be included in the results. These definitions can be large. So be cautious when including them. Defaults to false.", - "default":false + "description":"A comma-separate list of fields to optionally include. Valid options are 'definition' and 'total_feature_importance'. Default is none." }, "decompress_definition":{ "type":"boolean", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 5a437fc4166..31bda1d9a27 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -1,6 +1,24 @@ setup: - skip: - features: headers + features: + - headers + - allowed_warnings + - do: + allowed_warnings: + - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" + headers: + Content-Type: application/json + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + index: + id: trained_model_metadata-a-regression-model-0 + index: .ml-inference-000003 + body: + model_id: "a-regression-model-0" + doc_type: "trained_model_metadata" + total_feature_importance: + - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} + - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }} + - do: headers: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser @@ -548,6 +566,20 @@ setup: - match: { count: 12 } - match: { trained_model_configs.0.model_id: "a-regression-model-1" } --- +"Test get models with include total feature importance": + - do: + ml.get_trained_models: + model_id: "a-regression-model-*" + include: "total_feature_importance" + - match: { count: 2 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + - is_true: trained_model_configs.0.metadata.total_feature_importance + - length: { trained_model_configs.0.metadata.total_feature_importance: 2 } + - match: { trained_model_configs.1.model_id: "a-regression-model-1" } + - is_false: trained_model_configs.1.metadata.total_feature_importance + +--- "Test delete given unused trained model": - do: ml.delete_trained_model: @@ -824,7 +856,7 @@ setup: ml.get_trained_models: model_id: "a-regression-model-1" for_export: true - include_model_definition: true + include: "definition" decompress_definition: false - match: { trained_model_configs.0.description: "empty model for tests" }