[7.x] [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922) (#62620)

* [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922)

Adds a new flag, include, to the get trained models API.
The flag initially has two valid values: definition and total_feature_importance.
Consequently, the old include_model_definition flag is now deprecated.
When total_feature_importance is requested, a total_feature_importance field is added to the model's metadata object.
Including definition is equivalent to the previous include_model_definition=true.

* fixing test

* Update x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
Benjamin Trent 2020-09-18 10:07:35 -04:00 committed by GitHub
parent e1a4a3073a
commit e163559e4c
28 changed files with 833 additions and 162 deletions
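
For context, a minimal usage sketch (not part of the change itself) showing how a caller of the Java high-level REST client could request the new metadata once this is merged. The builder methods come from the diff below; the client wiring, the model id, and the GetTrainedModelsExample class name are illustrative assumptions.

import java.io.IOException;

import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.GetTrainedModelsResponse;

public class GetTrainedModelsExample {

    // Fetches one trained model, asking for both the definition and the new
    // total_feature_importance block that is surfaced in the model's metadata.
    static GetTrainedModelsResponse fetch(RestHighLevelClient client, String modelId) throws IOException {
        GetTrainedModelsRequest request = new GetTrainedModelsRequest(modelId)
            .includeDefinition()               // adds "definition" to the include set
            .includeTotalFeatureImportance()   // adds "total_feature_importance" to the include set
            .setDecompressDefinition(true);    // return the definition as JSON rather than compressed
        return client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);
    }
}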


@@ -779,9 +779,9 @@ final class MLRequestConverters {
             params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
                 Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
         }
-        if (getTrainedModelsRequest.getIncludeDefinition() != null) {
-            params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
-                Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
+        if (getTrainedModelsRequest.getIncludes().isEmpty() == false) {
+            params.putParam(GetTrainedModelsRequest.INCLUDE,
+                Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes()));
         }
         if (getTrainedModelsRequest.getTags() != null) {
             params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));


@@ -26,21 +26,26 @@ import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.common.Nullable;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 public class GetTrainedModelsRequest implements Validatable {
+    private static final String DEFINITION = "definition";
+    private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
     public static final String ALLOW_NO_MATCH = "allow_no_match";
-    public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
     public static final String FOR_EXPORT = "for_export";
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
     public static final String TAGS = "tags";
+    public static final String INCLUDE = "include";
     private final List<String> ids;
     private Boolean allowNoMatch;
-    private Boolean includeDefinition;
+    private Set<String> includes = new HashSet<>();
     private Boolean decompressDefinition;
     private Boolean forExport;
     private PageParams pageParams;
@@ -86,19 +91,32 @@ public class GetTrainedModelsRequest implements Validatable {
         return this;
     }
-    public Boolean getIncludeDefinition() {
-        return includeDefinition;
+    public Set<String> getIncludes() {
+        return Collections.unmodifiableSet(includes);
+    }
+    public GetTrainedModelsRequest includeDefinition() {
+        this.includes.add(DEFINITION);
+        return this;
+    }
+    public GetTrainedModelsRequest includeTotalFeatureImportance() {
+        this.includes.add(TOTAL_FEATURE_IMPORTANCE);
+        return this;
     }
     /**
      * Whether to include the full model definition.
      *
      * The full model definition can be very large.
-     *
+     * @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()}
      * @param includeDefinition If {@code true}, the definition is included.
      */
+    @Deprecated
     public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
-        this.includeDefinition = includeDefinition;
+        if (includeDefinition != null && includeDefinition) {
+            return this.includeDefinition();
+        }
         return this;
     }
@@ -173,13 +191,13 @@ public class GetTrainedModelsRequest implements Validatable {
         return Objects.equals(ids, other.ids)
             && Objects.equals(allowNoMatch, other.allowNoMatch)
             && Objects.equals(decompressDefinition, other.decompressDefinition)
-            && Objects.equals(includeDefinition, other.includeDefinition)
+            && Objects.equals(includes, other.includes)
             && Objects.equals(forExport, other.forExport)
             && Objects.equals(pageParams, other.pageParams);
     }
     @Override
     public int hashCode() {
-        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
+        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport);
     }
 }


@@ -0,0 +1,208 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class TotalFeatureImportance implements ToXContentObject {
private static final String NAME = "total_feature_importance";
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ParseField CLASSES = new ParseField("classes");
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
public static final ParseField MIN = new ParseField("min");
public static final ParseField MAX = new ParseField("max");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
}
public static TotalFeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public final String featureName;
public final Importance importance;
public final List<ClassImportance> classImportances;
TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
this.featureName = featureName;
this.importance = importance;
this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME.getPreferredName(), featureName);
if (importance != null) {
builder.field(IMPORTANCE.getPreferredName(), importance);
}
if (classImportances.isEmpty() == false) {
builder.field(CLASSES.getPreferredName(), classImportances);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TotalFeatureImportance that = (TotalFeatureImportance) o;
return Objects.equals(that.importance, importance)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(classImportances, that.classImportances);
}
@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportances);
}
public static class Importance implements ToXContentObject {
private static final String NAME = "importance";
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
static {
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
}
private final double meanMagnitude;
private final double min;
private final double max;
public Importance(double meanMagnitude, double min, double max) {
this.meanMagnitude = meanMagnitude;
this.min = min;
this.max = max;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Importance that = (Importance) o;
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
Double.compare(that.min, min) == 0 &&
Double.compare(that.max, max) == 0;
}
@Override
public int hashCode() {
return Objects.hash(meanMagnitude, min, max);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
builder.field(MIN.getPreferredName(), min);
builder.field(MAX.getPreferredName(), max);
builder.endObject();
return builder;
}
}
public static class ClassImportance implements ToXContentObject {
private static final String NAME = "total_class_importance";
public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new ClassImportance(a[0], (Importance)a[1]));
static {
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
}
public static ClassImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public final Object className;
public final Importance importance;
ClassImportance(Object className, Importance importance) {
this.className = className;
this.importance = importance;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(IMPORTANCE.getPreferredName(), importance);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassImportance that = (ClassImportance) o;
return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
}
@Override
public int hashCode() {
return Objects.hash(className, importance);
}
}
}


@@ -894,7 +894,7 @@ public class MLRequestConvertersTests extends ESTestCase {
         GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
             .setAllowNoMatch(false)
             .setDecompressDefinition(true)
-            .setIncludeDefinition(false)
+            .includeDefinition()
             .setTags("tag1", "tag2")
             .setPageParams(new PageParams(100, 300));
@@ -908,7 +908,7 @@ public class MLRequestConvertersTests extends ESTestCase {
             hasEntry("allow_no_match", "false"),
             hasEntry("decompress_definition", "true"),
             hasEntry("tags", "tag1,tag2"),
-            hasEntry("include_model_definition", "false")
+            hasEntry("include", "definition")
         ));
         assertNull(request.getEntity());
     }


@@ -2257,7 +2257,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         {
             GetTrainedModelsResponse getTrainedModelsResponse = execute(
-                new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true),
+                new GetTrainedModelsRequest(modelIdPrefix + 0)
+                    .setDecompressDefinition(true)
+                    .includeDefinition()
+                    .includeTotalFeatureImportance(),
                 machineLearningClient::getTrainedModels,
                 machineLearningClient::getTrainedModelsAsync);
@@ -2268,7 +2271,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
             getTrainedModelsResponse = execute(
-                new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true),
+                new GetTrainedModelsRequest(modelIdPrefix + 0)
+                    .setDecompressDefinition(false)
+                    .includeTotalFeatureImportance()
+                    .includeDefinition(),
                 machineLearningClient::getTrainedModels,
                 machineLearningClient::getTrainedModelsAsync);
@@ -2279,7 +2285,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
             getTrainedModelsResponse = execute(
-                new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false),
+                new GetTrainedModelsRequest(modelIdPrefix + 0)
+                    .setDecompressDefinition(false),
                 machineLearningClient::getTrainedModels,
                 machineLearningClient::getTrainedModelsAsync);
             assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));


@@ -3694,11 +3694,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             // tag::get-trained-models-request
             GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
                 .setPageParams(new PageParams(0, 1)) // <2>
-                .setIncludeDefinition(false) // <3>
-                .setDecompressDefinition(false) // <4>
-                .setAllowNoMatch(true) // <5>
-                .setTags("regression") // <6>
-                .setForExport(false); // <7>
+                .includeDefinition() // <3>
+                .includeTotalFeatureImportance() // <4>
+                .setDecompressDefinition(false) // <5>
+                .setAllowNoMatch(true) // <6>
+                .setTags("regression") // <7>
+                .setForExport(false); // <8>
             // end::get-trained-models-request
             request.setTags((List<String>)null);


@@ -0,0 +1,63 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
public static TotalFeatureImportance randomInstance() {
return new TotalFeatureImportance(
randomAlphaOfLength(10),
randomBoolean() ? null : randomImportance(),
randomBoolean() ?
null :
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList())
);
}
private static TotalFeatureImportance.Importance randomImportance() {
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
}
@Override
protected TotalFeatureImportance createTestInstance() {
return randomInstance();
}
@Override
protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
return TotalFeatureImportance.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}


@@ -22,26 +22,28 @@ IDs, or the special wildcard `_all` to get all trained models.
 --------------------------------------------------
 include-tagged::{doc-tests-file}[{api}-request]
 --------------------------------------------------
-<1> Constructing a new GET request referencing an existing Trained Model
+<1> Constructing a new GET request referencing an existing trained model
 <2> Set the paging parameters
 <3> Indicate if the complete model definition should be included
-<4> Should the definition be fully decompressed on GET
-<5> Allow empty response if no Trained Models match the provided ID patterns.
-    If false, an error will be thrown if no Trained Models match the
+<4> Indicate if the total feature importance for the features used in training
+    should be included in the model `metadata` field.
+<5> Should the definition be fully decompressed on GET
+<6> Allow empty response if no trained models match the provided ID patterns.
+    If false, an error will be thrown if no trained models match the
     ID patterns.
-<6> An optional list of tags used to narrow the model search. A Trained Model
+<7> An optional list of tags used to narrow the model search. A trained model
     can have many tags or none. The trained models in the response will
     contain all the provided tags.
-<7> Optional boolean value indicating if certain fields should be removed on
-    retrieval. This is useful for getting the trained model in a format that
-    can then be put into another cluster.
+<8> Optional boolean value for requesting the trained model in a format that can
+    then be put into another cluster. Certain fields that can only be set when
+    the model is imported are removed.
 include::../execution.asciidoc[]
 [id="{upid}-{api}-response"]
 ==== Response
-The returned +{response}+ contains the requested Trained Model.
+The returned +{response}+ contains the requested trained model.
 ["source","java",subs="attributes,callouts,macros"]
 --------------------------------------------------


@@ -29,18 +29,19 @@ experimental[]
 [[ml-get-inference-prereq]]
 == {api-prereq-title}
-Required privileges which should be added to a custom role:
+If the {es} {security-features} are enabled, you must have the following
+privileges:
 * cluster: `monitor_ml`
 For more information, see <<security-privileges>> and
 {ml-docs-setup-privileges}.
 [[ml-get-inference-desc]]
 == {api-description-title}
 You can get information for multiple trained models in a single API request by
 using a comma-separated list of model IDs or a wildcard expression.
@@ -48,7 +49,7 @@ using a comma-separated list of model IDs or a wildcard expression.
 == {api-path-parms-title}
 `<model_id>`::
 (Optional, string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
@@ -56,12 +57,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 == {api-query-parms-title}
 `allow_no_match`::
 (Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models]
 `decompress_definition`::
 (Optional, boolean)
 Specifies whether the included model definition should be returned as a JSON map
 (`true`) or in a custom compressed format (`false`). Defaults to `true`.
 `for_export`::
@@ -71,17 +72,21 @@ retrieval. This allows the model to be in an acceptable format to be retrieved
 and then added to another cluster. Default is false.
 `from`::
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models]
-`include_model_definition`::
-(Optional, boolean)
-Specifies whether the model definition is returned in the response. Defaults to
-`false`. When `true`, only a single model must match the ID patterns provided.
-Otherwise, a bad request is returned.
+`include`::
+(Optional, string)
+A comma delimited string of optional fields to include in the response body.
+Valid options are:
+- `definition`: Includes the model definition
+- `total_feature_importance`: Includes the total feature importance for the
+  training data set. This field is available in the `metadata` field in the
+  response body.
+Default is empty, indicating including no optional fields.
 `size`::
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models]
 `tags`::
@@ -94,7 +99,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tags]
 `trained_model_configs`::
 (array)
 An array of trained model resources, which are sorted by the `model_id` value in
 ascending order.
 +
 .Properties of trained model resources
@@ -132,8 +137,86 @@ The license level of the trained model.
 `metadata`:::
 (object)
 An object containing metadata about the trained model. For example, models
 created by {dfanalytics} contain `analysis_config` and `input` objects.
+.Properties of metadata
+[%collapsible%open]
+=====
+`total_feature_importance`:::
+(array)
+An array of the total feature importance for each feature used from
+the training data set. This array of objects is returned if {dfanalytics} trained
+the model and the request includes `total_feature_importance` in the `include`
+request parameter.
++
+.Properties of total feature importance
+[%collapsible%open]
+======
+`feature_name`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name]
+`importance`:::
+(object)
+A collection of feature importance statistics related to the training data set for this particular feature.
++
+.Properties of feature importance
+[%collapsible%open]
+=======
+`mean_magnitude`:::
+(double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
+`max`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
+`min`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
+=======
+`classes`:::
+(array)
+If the trained model is a classification model, feature importance statistics are gathered
+per target class value.
++
+.Properties of class feature importance
+[%collapsible%open]
+=======
+`class_name`:::
+(string)
+The target class value. Could be a string, boolean, or number.
+`importance`:::
+(object)
+A collection of feature importance statistics related to the training data set for this particular feature.
++
+.Properties of feature importance
+[%collapsible%open]
+========
+`mean_magnitude`:::
+(double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
+`max`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
+`min`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
+========
+=======
+======
+=====
 `model_id`:::
 (string)
@@ -152,13 +235,13 @@ The {es} version number in which the trained model was created.
 == {api-response-codes-title}
 `400`::
 If `include_model_definition` is `true`, this code indicates that more than
 one models match the ID pattern.
 `404` (Missing resources)::
 If `allow_no_match` is `false`, this code indicates that there are no
 resources that match the request or only partial matches for the request.
 [[ml-get-inference-example]]
 == {api-examples-title}
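
As an aside for readers of the docs above, a minimal low-level REST client sketch (not part of this change) that exercises the documented include query parameter; the host, port, and model id are assumptions made for the example.

import java.io.IOException;

import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class GetTrainedModelIncludeExample {
    public static void main(String[] args) throws IOException {
        // Assumed local cluster; adjust host/port for a real deployment.
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            Request request = new Request("GET", "/_ml/inference/my-trained-model");
            // Ask for the optional total_feature_importance block in the model's metadata field.
            request.addParameter("include", "total_feature_importance");
            Response response = client.performRequest(request);
            System.out.println(response.getStatusLine());
        }
    }
}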


@@ -780,6 +780,23 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that
 used to train the model, which defaults to `<dependent_variable>_prediction`.
 end::inference-config-results-field-processor[]
+tag::inference-metadata-feature-importance-feature-name[]
+The training feature name for which this importance was calculated.
+end::inference-metadata-feature-importance-feature-name[]
+tag::inference-metadata-feature-importance-magnitude[]
+The average magnitude of this feature across all the training data.
+This value is the average of the absolute values of the importance
+for this feature.
+end::inference-metadata-feature-importance-magnitude[]
+tag::inference-metadata-feature-importance-max[]
+The maximum importance value across all the training data for this
+feature.
+end::inference-metadata-feature-importance-max[]
+tag::inference-metadata-feature-importance-min[]
+The minimum importance value across all the training data for this
+feature.
+end::inference-metadata-feature-importance-min[]
 tag::influencers[]
 A comma separated list of influencer field names. Typically these can be the by,
 over, or partition fields that are used in the detector configuration. You might
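
For illustration only (not part of the change), a small sketch of how the statistics described by the shared doc tags above relate to raw per-document importance values; the sample numbers are assumptions.

import java.util.Arrays;
import java.util.List;

public class FeatureImportanceStatsExample {
    public static void main(String[] args) {
        // Hypothetical importance values for one feature across the training data.
        List<Double> importances = Arrays.asList(1.5, -2.0, 0.25, -0.75);

        // mean_magnitude: average of the absolute importance values.
        double meanMagnitude = importances.stream().mapToDouble(Math::abs).average().orElse(0.0);
        // min and max: smallest and largest importance values seen across the training data.
        double min = importances.stream().mapToDouble(Double::doubleValue).min().orElse(0.0);
        double max = importances.stream().mapToDouble(Double::doubleValue).max().orElse(0.0);

        System.out.printf("mean_magnitude=%.4f min=%.2f max=%.2f%n", meanMagnitude, min, max);
    }
}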


@@ -10,15 +10,19 @@ import org.elasticsearch.action.ActionType;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.Set;
 public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
@@ -32,23 +36,60 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
     public static class Request extends AbstractGetResourcesRequest {
-        public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
+        static final String DEFINITION = "definition";
+        static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
+        private static final Set<String> KNOWN_INCLUDES;
+        static {
+            HashSet<String> includes = new HashSet<>(2, 1.0f);
+            includes.add(DEFINITION);
+            includes.add(TOTAL_FEATURE_IMPORTANCE);
+            KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
+        }
+        public static final ParseField INCLUDE = new ParseField("include");
+        public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
         public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
         public static final ParseField TAGS = new ParseField("tags");
-        private final boolean includeModelDefinition;
+        private final Set<String> includes;
         private final List<String> tags;
+        @Deprecated
         public Request(String id, boolean includeModelDefinition, List<String> tags) {
             setResourceId(id);
             setAllowNoResources(true);
-            this.includeModelDefinition = includeModelDefinition;
             this.tags = tags == null ? Collections.emptyList() : tags;
+            if (includeModelDefinition) {
+                this.includes = new HashSet<>(Collections.singletonList(DEFINITION));
+            } else {
+                this.includes = Collections.emptySet();
+            }
+        }
+        public Request(String id, List<String> tags, Set<String> includes) {
+            setResourceId(id);
+            setAllowNoResources(true);
+            this.tags = tags == null ? Collections.emptyList() : tags;
+            this.includes = includes == null ? Collections.emptySet() : includes;
+            Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
+            if (unknownIncludes.isEmpty() == false) {
+                throw ExceptionsHelper.badRequestException(
+                    "unknown [include] parameters {}. Valid options are {}",
+                    unknownIncludes,
+                    KNOWN_INCLUDES);
+            }
         }
         public Request(StreamInput in) throws IOException {
             super(in);
-            this.includeModelDefinition = in.readBoolean();
+            if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
+                this.includes = in.readSet(StreamInput::readString);
+            } else {
+                Set<String> includes = new HashSet<>();
+                if (in.readBoolean()) {
+                    includes.add(DEFINITION);
+                }
+                this.includes = includes;
+            }
             if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
                 this.tags = in.readStringList();
             } else {
@@ -62,7 +103,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         }
         public boolean isIncludeModelDefinition() {
-            return includeModelDefinition;
+            return this.includes.contains(DEFINITION);
+        }
+        public boolean isIncludeTotalFeatureImportance() {
+            return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
         }
         public List<String> getTags() {
@@ -72,7 +117,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
-            out.writeBoolean(includeModelDefinition);
+            if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
+                out.writeCollection(this.includes, StreamOutput::writeString);
+            } else {
+                out.writeBoolean(this.includes.contains(DEFINITION));
+            }
             if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
                 out.writeStringCollection(tags);
             }
@@ -80,7 +129,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         @Override
         public int hashCode() {
-            return Objects.hash(super.hashCode(), includeModelDefinition, tags);
+            return Objects.hash(super.hashCode(), includes, tags);
         }
         @Override
@@ -92,7 +141,18 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
                 return false;
             }
             Request other = (Request) obj;
-            return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
+            return super.equals(obj) && Objects.equals(includes, other.includes) && Objects.equals(tags, other.tags);
+        }
+        @Override
+        public String toString() {
+            return "Request{" +
+                "includes=" + includes +
+                ", tags=" + tags +
+                ", page=" + getPageParams() +
+                ", id=" + getResourceId() +
+                ", allow_missing=" + isAllowNoResources() +
+                '}';
         }
     }


@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;
@@ -39,6 +40,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.stream.Collectors;
 import static org.elasticsearch.action.ValidateActions.addValidationError;
@@ -51,6 +53,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
     public static final String FOR_EXPORT = "for_export";
+    public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
+    private static final Set<String> RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE);
     private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
@@ -419,7 +423,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
         this.description = config.getDescription();
         this.tags = config.getTags();
-        this.metadata = config.getMetadata();
+        this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata());
         this.input = config.getInput();
         this.estimatedOperations = config.estimatedOperations;
         this.estimatedHeapMemory = config.estimatedHeapMemory;
@@ -471,6 +475,18 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return this;
         }
+        public Builder setFeatureImportance(List<TotalFeatureImportance> totalFeatureImportance) {
+            if (totalFeatureImportance == null) {
+                return this;
+            }
+            if (this.metadata == null) {
+                this.metadata = new HashMap<>();
+            }
+            this.metadata.put(TOTAL_FEATURE_IMPORTANCE,
+                totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()));
+            return this;
+        }
         public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
             if (definition == null) {
                 return this;
@@ -627,6 +643,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                     ESTIMATED_OPERATIONS.getPreferredName(),
                     validationException);
                 validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
+                if (metadata != null) {
+                    validationException = checkIllegalSetting(
+                        metadata.get(TOTAL_FEATURE_IMPORTANCE),
+                        METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE,
+                        validationException);
+                }
             }
             if (validationException != null) {


@@ -20,8 +20,11 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import java.io.IOException;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
 public class TotalFeatureImportance implements ToXContentObject, Writeable {
@@ -81,16 +84,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        builder.field(FEATURE_NAME.getPreferredName(), featureName);
-        if (importance != null) {
-            builder.field(IMPORTANCE.getPreferredName(), importance);
-        }
-        if (classImportances.isEmpty() == false) {
-            builder.field(CLASSES.getPreferredName(), classImportances);
-        }
-        builder.endObject();
-        return builder;
+        return builder.map(asMap());
     }
     @Override
@@ -103,6 +97,18 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             && Objects.equals(classImportances, that.classImportances);
     }
+    public Map<String, Object> asMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(FEATURE_NAME.getPreferredName(), featureName);
+        if (importance != null) {
+            map.put(IMPORTANCE.getPreferredName(), importance.asMap());
+        }
+        if (classImportances.isEmpty() == false) {
+            map.put(CLASSES.getPreferredName(), classImportances.stream().map(ClassImportance::asMap).collect(Collectors.toList()));
+        }
+        return map;
+    }
     @Override
     public int hashCode() {
         return Objects.hash(featureName, importance, classImportances);
@@ -165,12 +171,15 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
-            builder.field(MIN.getPreferredName(), min);
-            builder.field(MAX.getPreferredName(), max);
-            builder.endObject();
-            return builder;
+            return builder.map(asMap());
+        }
+        private Map<String, Object> asMap() {
+            Map<String, Object> map = new LinkedHashMap<>();
+            map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
+            map.put(MIN.getPreferredName(), min);
+            map.put(MAX.getPreferredName(), max);
+            return map;
         }
     }
@@ -229,11 +238,14 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(CLASS_NAME.getPreferredName(), className);
-            builder.field(IMPORTANCE.getPreferredName(), importance);
-            builder.endObject();
-            return builder;
+            return builder.map(asMap());
+        }
+        private Map<String, Object> asMap() {
+            Map<String, Object> map = new LinkedHashMap<>();
+            map.put(CLASS_NAME.getPreferredName(), className);
+            map.put(IMPORTANCE.getPreferredName(), importance.asMap());
+            return map;
         }
         @Override


@@ -53,6 +53,10 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
         return NAME + "-" + modelId;
     }
+    public static String modelId(String docId) {
+        return docId.substring(NAME.length() + 1);
+    }
     private final List<TotalFeatureImportance> totalFeatureImportances;
     private final String modelId;


@@ -103,7 +103,7 @@ public final class Messages {
     public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
         "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
     public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
-    public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]";
+    public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}";
     public static final String INFERENCE_CANNOT_DELETE_MODEL =
         "Unable to delete model [{0}]";
     public static final String MODEL_DEFINITION_TRUNCATED =


@@ -5,19 +5,28 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.action.util.PageParams;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
-public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
     @Override
     protected Request createTestInstance() {
         Request request = new Request(randomAlphaOfLength(20),
-            randomBoolean(),
             randomBoolean() ? null :
-                randomList(10, () -> randomAlphaOfLength(10)));
+                randomList(10, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null :
+                Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE))
+                    .limit(4)
+                    .collect(Collectors.toSet()));
         request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
         return request;
     }
@@ -26,4 +35,22 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
     protected Writeable.Reader<Request> instanceReader() {
         return Request::new;
     }
+    @Override
+    protected Request mutateInstanceForVersion(Request instance, Version version) {
+        if (version.before(Version.V_7_10_0)) {
+            Set<String> includes = new HashSet<>();
+            if (instance.isIncludeModelDefinition()) {
+                includes.add(Request.DEFINITION);
+            }
+            Request request = new Request(
+                instance.getResourceId(),
+                version.before(Version.V_7_7_0) ? null : instance.getTags(),
+                includes);
+            request.setPageParams(instance.getPageParams());
+            request.setAllowNoResources(instance.isAllowNoResources());
+            return request;
+        }
+        return instance;
+    }
 }


@@ -42,11 +42,13 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.startsWith;
 public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
@@ -95,19 +97,21 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
         trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
         Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
         assertThat(ids.v1(), equalTo(1L));
+        String inferenceModelId = ids.v2().iterator().next();
         PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
-        trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);
+        trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture);
         TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
         assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
         assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
         assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
+        assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
-        PlainActionFuture<TrainedModelMetadata> getTrainedMetadataFuture = new PlainActionFuture<>();
-        trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture);
-        TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet();
+        PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
+        trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
+        TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
         assertThat(storedMetadata.getModelId(), startsWith(modelId));
         assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances()));
     }


@ -90,7 +90,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
assertThat(exceptionHolder.get(), is(nullValue())); assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
getConfigHolder.get().ensureParsedDefinition(xContentRegistry()); getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(config)); assertThat(getConfigHolder.get(), equalTo(config));
@ -121,7 +124,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
assertThat(exceptionHolder.get(), is(nullValue())); assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder); blockingCall(listener ->
trainedModelProvider.getTrainedModel(modelId, false, false, listener),
getConfigHolder,
exceptionHolder);
getConfigHolder.get().ensureParsedDefinition(xContentRegistry()); getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition)); assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition));
@ -132,7 +138,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
String modelId = "test-get-missing-trained-model-config"; String modelId = "test-get-missing-trained-model-config";
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>(); AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
@ -154,7 +163,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
.actionGet(); .actionGet();
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
@ -193,7 +205,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
} }
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(getConfigHolder.get(), is(nullValue())); assertThat(getConfigHolder.get(), is(nullValue()));
assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
@ -238,7 +253,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
} }
} }
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>(); AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(getConfigHolder.get(), is(nullValue())); assertThat(getConfigHolder.get(), is(nullValue()));
assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
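Note that the integration tests above only ever pass false for the new includeTotalFeatureImportance flag. As a rough, hypothetical sketch (not part of this change set), a test exercising the flag could look like the following; it assumes a stored trained_model_metadata document exists for modelId, reuses the surrounding blockingCall helper, and treats getMetadata() as the accessor through which the merged values surface.

AtomicReference<TrainedModelConfig> configHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
// Third argument asks the provider to merge total feature importance from the metadata index.
blockingCall(
    listener -> trainedModelProvider.getTrainedModel(modelId, false, true, listener),
    configHolder,
    exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
// Assumption for this sketch: the merged values land under the config's metadata map.
assertThat(configHolder.get().getMetadata(), hasKey("total_feature_importance"));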
View File
@ -57,15 +57,25 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
} }
if (request.isIncludeModelDefinition()) { if (request.isIncludeModelDefinition()) {
provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap( provider.getTrainedModel(
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), totalAndIds.v2().iterator().next(),
listener::onFailure true,
)); request.isIncludeTotalFeatureImportance(),
ActionListener.wrap(
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
listener::onFailure
)
);
} else { } else {
provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap( provider.getTrainedModels(
configs -> listener.onResponse(responseBuilder.setModels(configs).build()), totalAndIds.v2(),
listener::onFailure request.isAllowNoResources(),
)); request.isIncludeTotalFeatureImportance(),
ActionListener.wrap(
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
listener::onFailure
)
);
} }
}, },
listener::onFailure listener::onFailure
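The transport action now threads request.isIncludeTotalFeatureImportance() into both provider calls. That accessor is not shown in this diff; a plausible backing for it, assuming the action request carries the parsed include set from the REST layer, is a simple membership check. The field and literal values below are assumptions for illustration, not the committed implementation.

// Hypothetical sketch of the accessors used above; `includes` is assumed to hold the
// values parsed from the `include` request parameter.
public boolean isIncludeTotalFeatureImportance() {
    return includes.contains("total_feature_importance");
}

public boolean isIncludeModelDefinition() {
    return includes.contains("definition");
}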
View File
@ -82,7 +82,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
responseBuilder.setLicensed(true); responseBuilder.setLicensed(true);
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener); this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
} else { } else {
trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap( trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap(
trainedModelConfig -> { trainedModelConfig -> {
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel())); responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) { if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
View File
@ -270,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener {
} }
private void loadModel(String modelId, Consumer consumer) { private void loadModel(String modelId, Consumer consumer) {
provider.getTrainedModel(modelId, false, ActionListener.wrap( provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
trainedModelConfig -> { trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap( provider.getTrainedModelForInference(modelId, ActionListener.wrap(
@ -306,7 +306,7 @@ public class ModelLoadingService implements ClusterStateListener {
// If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline // by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, false, ActionListener.wrap( provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
trainedModelConfig -> { trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM // Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
@ -434,7 +434,7 @@ public class ModelLoadingService implements ClusterStateListener {
logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]", logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
notification.getValue().model.getModelId())); notification.getValue().model.getModelId()));
// If the model is no longer referenced, flush the stats to persist as soon as possible // If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false); notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
} finally { } finally {
View File
@ -89,9 +89,11 @@ import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -235,14 +237,14 @@ public class TrainedModelProvider {
)); ));
} }
public void getTrainedModelMetadata(String modelId, ActionListener<TrainedModelMetadata> listener) { public void getTrainedModelMetadata(Collection<String> modelIds, ActionListener<Map<String, TrainedModelMetadata>> listener) {
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
.boolQuery() .boolQuery()
.filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
TrainedModelMetadata.NAME)))) TrainedModelMetadata.NAME))))
.setSize(1) .setSize(10_000)
// First find the latest index // First find the latest index
.addSort("_index", SortOrder.DESC) .addSort("_index", SortOrder.DESC)
.request(); .request();
@ -250,18 +252,20 @@ public class TrainedModelProvider {
searchResponse -> { searchResponse -> {
if (searchResponse.getHits().getHits().length == 0) { if (searchResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException( listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
return; return;
} }
List<TrainedModelMetadata> metadataList = handleHits(searchResponse.getHits().getHits(), HashMap<String, TrainedModelMetadata> map = new HashMap<>();
modelId, for (SearchHit hit : searchResponse.getHits().getHits()) {
this::parseMetadataLenientlyFromSource); String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId()));
listener.onResponse(metadataList.get(0)); map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId));
}
listener.onResponse(map);
}, },
e -> { e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
listener.onFailure(new ResourceNotFoundException( listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
return; return;
} }
listener.onFailure(e); listener.onFailure(e);
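With this hunk, getTrainedModelMetadata accepts a collection of ids and returns a map keyed by model id instead of a single document. A minimal, hedged usage sketch follows; the model ids are example values matching the YAML fixtures later in this change, and `provider` is an existing TrainedModelProvider.

provider.getTrainedModelMetadata(
    Arrays.asList("a-regression-model-0", "a-regression-model-1"),
    ActionListener.wrap(
        metadataByModelId -> metadataByModelId.forEach((id, metadata) ->
            // getTotalFeatureImportances() is the per-model list later merged into the config metadata
            logger.info("[{}] has [{}] total feature importance entries",
                id, metadata.getTotalFeatureImportances().size())),
        e -> logger.warn("no trained_model_metadata docs found", e) // surfaces as ResourceNotFoundException
    ));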
@ -371,7 +375,7 @@ public class TrainedModelProvider {
// TODO Change this when we get more than just langIdent stored // TODO Change this when we get more than just langIdent stored
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try { try {
TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry); TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork; assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
listener.onResponse( listener.onResponse(
InferenceDefinition.builder() InferenceDefinition.builder()
@ -434,18 +438,50 @@ public class TrainedModelProvider {
)); ));
} }
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) { public void getTrainedModel(final String modelId,
final boolean includeDefinition,
final boolean includeTotalFeatureImportance,
final ActionListener<TrainedModelConfig> finalListener) {
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try { try {
listener.onResponse(loadModelFromResource(modelId, includeDefinition == false)); finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
return; return;
} catch (ElasticsearchException ex) { } catch (ElasticsearchException ex) {
listener.onFailure(ex); finalListener.onFailure(ex);
return; return;
} }
} }
ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
modelBuilder -> {
if (includeTotalFeatureImportance == false) {
finalListener.onResponse(modelBuilder.build());
return;
}
this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap(
metadata -> {
TrainedModelMetadata modelMetadata = metadata.get(modelId);
if (modelMetadata != null) {
modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
finalListener.onResponse(modelBuilder.build());
},
failure -> {
// total feature importance is not necessary for a model to be valid
// we shouldn't fail if it is not found
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
finalListener.onResponse(modelBuilder.build());
return;
}
finalListener.onFailure(failure);
}
));
},
finalListener::onFailure
);
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
.idsQuery() .idsQuery()
.addIds(modelId)); .addIds(modelId));
@ -483,11 +519,11 @@ public class TrainedModelProvider {
try { try {
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource); builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
} catch (ResourceNotFoundException ex) { } catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException( getTrainedModelListener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return; return;
} catch (Exception ex) { } catch (Exception ex) {
listener.onFailure(ex); getTrainedModelListener.onFailure(ex);
return; return;
} }
@ -500,22 +536,22 @@ public class TrainedModelProvider {
String compressedString = getDefinitionFromDocs(docs, modelId); String compressedString = getDefinitionFromDocs(docs, modelId);
builder.setDefinitionFromString(compressedString); builder.setDefinitionFromString(compressedString);
} catch (ElasticsearchException elasticsearchException) { } catch (ElasticsearchException elasticsearchException) {
listener.onFailure(elasticsearchException); getTrainedModelListener.onFailure(elasticsearchException);
return; return;
} }
} catch (ResourceNotFoundException ex) { } catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException( getTrainedModelListener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return; return;
} catch (Exception ex) { } catch (Exception ex) {
listener.onFailure(ex); getTrainedModelListener.onFailure(ex);
return; return;
} }
} }
listener.onResponse(builder.build()); getTrainedModelListener.onResponse(builder);
}, },
listener::onFailure getTrainedModelListener::onFailure
); );
executeAsyncWithOrigin(client, executeAsyncWithOrigin(client,
@ -532,7 +568,10 @@ public class TrainedModelProvider {
* This does no expansion on the ids. * This does no expansion on the ids.
* It assumes that there are fewer than 10k. * It assumes that there are fewer than 10k.
*/ */
public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) { public void getTrainedModels(Set<String> modelIds,
boolean allowNoResources,
boolean includeTotalFeatureImportance,
final ActionListener<List<TrainedModelConfig>> finalListener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
@ -541,23 +580,63 @@ public class TrainedModelProvider {
.setQuery(queryBuilder) .setQuery(queryBuilder)
.setSize(modelIds.size()) .setSize(modelIds.size())
.request(); .request();
List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size()); List<TrainedModelConfig.Builder> configs = new ArrayList<>(modelIds.size());
Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE); Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds); Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
for(String modelId : modelsAsResource) { for(String modelId : modelsAsResource) {
try { try {
configs.add(loadModelFromResource(modelId, true)); configs.add(loadModelFromResource(modelId, true));
} catch (ElasticsearchException ex) { } catch (ElasticsearchException ex) {
listener.onFailure(ex); finalListener.onFailure(ex);
return; return;
} }
} }
if (modelsInIndex.isEmpty()) { if (modelsInIndex.isEmpty()) {
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); finalListener.onResponse(configs.stream()
listener.onResponse(configs); .map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return; return;
} }
ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
modelBuilders -> {
if (includeTotalFeatureImportance == false) {
finalListener.onResponse(modelBuilders.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}
this.getTrainedModelMetadata(modelIds, ActionListener.wrap(
metadata ->
finalListener.onResponse(modelBuilders.stream()
.map(builder -> {
TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
if (modelMetadata != null) {
builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
return builder.build();
})
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList())),
failure -> {
// total feature importance is not necessary for a model to be valid
// we shouldn't fail if it is not found
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
finalListener.onResponse(modelBuilders.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}
finalListener.onFailure(failure);
}
));
},
finalListener::onFailure
);
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap( ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
searchResponse -> { searchResponse -> {
Set<String> observedIds = new HashSet<>( Set<String> observedIds = new HashSet<>(
@ -568,12 +647,12 @@ public class TrainedModelProvider {
try { try {
if (observedIds.contains(searchHit.getId()) == false) { if (observedIds.contains(searchHit.getId()) == false) {
configs.add( configs.add(
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build() parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId())
); );
observedIds.add(searchHit.getId()); observedIds.add(searchHit.getId());
} }
} catch (IOException ex) { } catch (IOException ex) {
listener.onFailure( getTrainedModelListener.onFailure(
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId())); ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
return; return;
} }
@ -583,14 +662,13 @@ public class TrainedModelProvider {
// Otherwise, treat it as if it was never expanded to begin with. // Otherwise, treat it as if it was never expanded to begin with.
Set<String> missingConfigs = Sets.difference(modelIds, observedIds); Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
if (missingConfigs.isEmpty() == false && allowNoResources == false) { if (missingConfigs.isEmpty() == false && allowNoResources == false) {
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
return; return;
} }
// Ensure sorted even with the injection of locally resourced models // Ensure sorted even with the injection of locally resourced models
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); getTrainedModelListener.onResponse(configs);
listener.onResponse(configs);
}, },
listener::onFailure getTrainedModelListener::onFailure
); );
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler); executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
@ -639,7 +717,7 @@ public class TrainedModelProvider {
foundResourceIds = new HashSet<>(); foundResourceIds = new HashSet<>();
for(String resourceId : matchedResourceIds) { for(String resourceId : matchedResourceIds) {
// Does the model as a resource have all the tags? // Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) {
foundResourceIds.add(resourceId); foundResourceIds.add(resourceId);
} }
} }
@ -833,7 +911,7 @@ public class TrainedModelProvider {
return QueryBuilders.constantScoreQuery(boolQueryBuilder); return QueryBuilders.constantScoreQuery(boolQueryBuilder);
} }
TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) { TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) {
URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT); URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
if (resource == null) { if (resource == null) {
logger.error("[{}] presumed stored as a resource but not found", modelId); logger.error("[{}] presumed stored as a resource but not found", modelId);
@ -848,7 +926,7 @@ public class TrainedModelProvider {
if (nullOutDefinition) { if (nullOutDefinition) {
builder.clearDefinition(); builder.clearDefinition();
} }
return builder.build(); return builder;
} catch (IOException ioEx) { } catch (IOException ioEx) {
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx); logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId); throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
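Taken together, the provider-level changes mean a single call can return a config with total feature importance already merged in. A hedged caller-side sketch is below; the client instance, the example model id, and the getMetadata() accessor are assumptions for illustration rather than code from this commit.

TrainedModelProvider provider = new TrainedModelProvider(client, xContentRegistry);
provider.getTrainedModel(
    "a-regression-model-0",   // example id, mirrors the YAML test fixtures
    false,                    // skip the (potentially large) compressed definition
    true,                     // merge total_feature_importance from trained_model_metadata docs
    ActionListener.wrap(
        config -> logger.info("[{}] metadata: {}", config.getModelId(), config.getMetadata()),
        e -> logger.warn("model lookup failed", e)
    ));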
View File
@ -25,6 +25,7 @@ import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -56,12 +57,17 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
if (Strings.isNullOrEmpty(modelId)) { if (Strings.isNullOrEmpty(modelId)) {
modelId = Metadata.ALL; modelId = Metadata.ALL;
} }
boolean includeModelDefinition = restRequest.paramAsBoolean(
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
false
);
List<String> tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY)); List<String> tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags); Set<String> includes = new HashSet<>(
asList(
restRequest.paramAsStringArray(
GetTrainedModelsAction.Request.INCLUDE.getPreferredName(),
Strings.EMPTY_ARRAY)));
final GetTrainedModelsAction.Request request = restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION) ?
new GetTrainedModelsAction.Request(modelId,
restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false),
tags) :
new GetTrainedModelsAction.Request(modelId, tags, includes);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
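For clarity: the handler above keeps honouring the legacy include_model_definition flag when it is explicitly supplied, and otherwise builds the request from the new include list. A hedged illustration of the two equivalent constructions follows; "my-model" and the empty tag list are example values, not taken from the diff.

// Legacy form, e.g. ?include_model_definition=true (still parsed, now superseded)
GetTrainedModelsAction.Request legacy =
    new GetTrainedModelsAction.Request("my-model", true, Collections.emptyList());
// New form, e.g. ?include=definition
GetTrainedModelsAction.Request current =
    new GetTrainedModelsAction.Request("my-model", Collections.emptyList(), Collections.singleton("definition"));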
View File
@ -437,9 +437,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
// the loading occurred or which models are currently in the cache due to evictions. // the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three // Verify that we have at least loaded all three
assertBusy(() -> { assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any());
}); });
assertBusy(() -> { assertBusy(() -> {
assertThat(circuitBreaker.getUsed(), equalTo(10L)); assertThat(circuitBreaker.getUsed(), equalTo(10L));
@ -553,10 +553,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onResponse(trainedModelConfig); listener.onResponse(trainedModelConfig);
return null; return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -564,20 +564,20 @@ public class ModelLoadingServiceTests extends ESTestCase {
if (randomBoolean()) { if (randomBoolean()) {
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onFailure(new ResourceNotFoundException( listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return null; return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
} else { } else {
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L); when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onResponse(trainedModelConfig); listener.onResponse(trainedModelConfig);
return null; return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
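The recurring change in these test stubs is that getTrainedModel now takes four arguments (model id, include definition, include total feature importance, listener), so each doAnswer gains an extra eq(false) matcher and reads the listener from index 3 instead of 2. Consolidated once for readability, mirroring the stubs above:

doAnswer(invocationOnMock -> {
    @SuppressWarnings("rawtypes")
    ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3]; // listener moved from [2] to [3]
    listener.onResponse(trainedModelConfig);
    return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());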
View File
@ -57,14 +57,14 @@ public class TrainedModelProviderTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) { for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>(); PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, true, future); trainedModelProvider.getTrainedModel(modelId, true, false, future);
TrainedModelConfig configWithDefinition = future.actionGet(); TrainedModelConfig configWithDefinition = future.actionGet();
assertThat(configWithDefinition.getModelId(), equalTo(modelId)); assertThat(configWithDefinition.getModelId(), equalTo(modelId));
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue()))); assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>(); PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition); trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition);
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet(); TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
assertThat(configWithoutDefinition.getModelId(), equalTo(modelId)); assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));
View File
@ -33,7 +33,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>(); PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls // Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future); trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future);
TrainedModelConfig config = future.actionGet(); TrainedModelConfig config = future.actionGet();
config.ensureParsedDefinition(xContentRegistry()); config.ensureParsedDefinition(xContentRegistry());
View File
@ -34,11 +34,10 @@
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)", "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
"default":true "default":true
}, },
"include_model_definition":{ "include":{
"type":"boolean", "type":"string",
"required":false, "required":false,
"description":"Should the full model definition be included in the results. These definitions can be large. So be cautious when including them. Defaults to false.", "description":"A comma-separate list of fields to optionally include. Valid options are 'definition' and 'total_feature_importance'. Default is none."
"default":false
}, },
"decompress_definition":{ "decompress_definition":{
"type":"boolean", "type":"boolean",
View File
@ -1,6 +1,24 @@
setup: setup:
- skip: - skip:
features: headers features:
- headers
- allowed_warnings
- do:
allowed_warnings:
- "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template"
headers:
Content-Type: application/json
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_metadata-a-regression-model-0
index: .ml-inference-000003
body:
model_id: "a-regression-model-0"
doc_type: "trained_model_metadata"
total_feature_importance:
- { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
- { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
- do: - do:
headers: headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
@ -548,6 +566,20 @@ setup:
- match: { count: 12 } - match: { count: 12 }
- match: { trained_model_configs.0.model_id: "a-regression-model-1" } - match: { trained_model_configs.0.model_id: "a-regression-model-1" }
--- ---
"Test get models with include total feature importance":
- do:
ml.get_trained_models:
model_id: "a-regression-model-*"
include: "total_feature_importance"
- match: { count: 2 }
- length: { trained_model_configs: 2 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
- is_true: trained_model_configs.0.metadata.total_feature_importance
- length: { trained_model_configs.0.metadata.total_feature_importance: 2 }
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
- is_false: trained_model_configs.1.metadata.total_feature_importance
---
"Test delete given unused trained model": "Test delete given unused trained model":
- do: - do:
ml.delete_trained_model: ml.delete_trained_model:
@ -824,7 +856,7 @@ setup:
ml.get_trained_models: ml.get_trained_models:
model_id: "a-regression-model-1" model_id: "a-regression-model-1"
for_export: true for_export: true
include_model_definition: true include: "definition"
decompress_definition: false decompress_definition: false
- match: { trained_model_configs.0.description: "empty model for tests" } - match: { trained_model_configs.0.description: "empty model for tests" }