[7.x] [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922) (#62620)

* [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922)

Adds a new flag, include, to the get trained models API.
The flag initially has two valid values: definition and total_feature_importance.
Consequently, the old include_model_definition flag is now deprecated.
When total_feature_importance is requested, the total_feature_importance field is included in the model's metadata object.
Including definition is equivalent to the previous include_model_definition=true.
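
For reference, a minimal sketch of how the new builder methods could be used through the 7.x high-level REST client. The client setup and model id are placeholders, and reading the metadata map at the end assumes the client-side TrainedModelConfig exposes its metadata as a Map:

```java
import java.util.Map;

import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;

public class GetTrainedModelIncludesExample {

    static void fetchWithIncludes(RestHighLevelClient client) throws Exception {
        // Request a single model and opt in to both optional fields added by this change.
        GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model")
            .includeDefinition()               // replaces the deprecated setIncludeDefinition(true)
            .includeTotalFeatureImportance()   // new: adds total_feature_importance to metadata
            .setDecompressDefinition(true);

        GetTrainedModelsResponse response =
            client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);

        for (TrainedModelConfig config : response.getTrainedModels()) {
            // total_feature_importance is surfaced inside the model's metadata object.
            Map<String, Object> metadata = config.getMetadata();
            Object totalFeatureImportance =
                metadata == null ? null : metadata.get("total_feature_importance");
            System.out.println(config.getModelId() + " -> " + totalFeatureImportance);
        }
    }
}
```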

* fixing test

* Update x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
Benjamin Trent 2020-09-18 10:07:35 -04:00 committed by GitHub
parent e1a4a3073a
commit e163559e4c
28 changed files with 833 additions and 162 deletions

View File

@ -779,9 +779,9 @@ final class MLRequestConverters {
params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
}
if (getTrainedModelsRequest.getIncludeDefinition() != null) {
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
if (getTrainedModelsRequest.getIncludes().isEmpty() == false) {
params.putParam(GetTrainedModelsRequest.INCLUDE,
Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes()));
}
if (getTrainedModelsRequest.getTags() != null) {
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));

View File

@ -26,21 +26,26 @@ import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.common.Nullable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
public class GetTrainedModelsRequest implements Validatable {
private static final String DEFINITION = "definition";
private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
public static final String ALLOW_NO_MATCH = "allow_no_match";
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final String FOR_EXPORT = "for_export";
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String TAGS = "tags";
public static final String INCLUDE = "include";
private final List<String> ids;
private Boolean allowNoMatch;
private Boolean includeDefinition;
private Set<String> includes = new HashSet<>();
private Boolean decompressDefinition;
private Boolean forExport;
private PageParams pageParams;
@ -86,19 +91,32 @@ public class GetTrainedModelsRequest implements Validatable {
return this;
}
public Boolean getIncludeDefinition() {
return includeDefinition;
public Set<String> getIncludes() {
return Collections.unmodifiableSet(includes);
}
public GetTrainedModelsRequest includeDefinition() {
this.includes.add(DEFINITION);
return this;
}
public GetTrainedModelsRequest includeTotalFeatureImportance() {
this.includes.add(TOTAL_FEATURE_IMPORTANCE);
return this;
}
/**
* Whether to include the full model definition.
*
* The full model definition can be very large.
*
* @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()}
* @param includeDefinition If {@code true}, the definition is included.
*/
@Deprecated
public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
this.includeDefinition = includeDefinition;
if (includeDefinition != null && includeDefinition) {
return this.includeDefinition();
}
return this;
}
@ -173,13 +191,13 @@ public class GetTrainedModelsRequest implements Validatable {
return Objects.equals(ids, other.ids)
&& Objects.equals(allowNoMatch, other.allowNoMatch)
&& Objects.equals(decompressDefinition, other.decompressDefinition)
&& Objects.equals(includeDefinition, other.includeDefinition)
&& Objects.equals(includes, other.includes)
&& Objects.equals(forExport, other.forExport)
&& Objects.equals(pageParams, other.pageParams);
}
@Override
public int hashCode() {
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport);
}
}

View File

@ -0,0 +1,208 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class TotalFeatureImportance implements ToXContentObject {
private static final String NAME = "total_feature_importance";
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ParseField CLASSES = new ParseField("classes");
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
public static final ParseField MIN = new ParseField("min");
public static final ParseField MAX = new ParseField("max");
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
}
public static TotalFeatureImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public final String featureName;
public final Importance importance;
public final List<ClassImportance> classImportances;
TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
this.featureName = featureName;
this.importance = importance;
this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME.getPreferredName(), featureName);
if (importance != null) {
builder.field(IMPORTANCE.getPreferredName(), importance);
}
if (classImportances.isEmpty() == false) {
builder.field(CLASSES.getPreferredName(), classImportances);
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TotalFeatureImportance that = (TotalFeatureImportance) o;
return Objects.equals(that.importance, importance)
&& Objects.equals(featureName, that.featureName)
&& Objects.equals(classImportances, that.classImportances);
}
@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportances);
}
public static class Importance implements ToXContentObject {
private static final String NAME = "importance";
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
static {
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
}
private final double meanMagnitude;
private final double min;
private final double max;
public Importance(double meanMagnitude, double min, double max) {
this.meanMagnitude = meanMagnitude;
this.min = min;
this.max = max;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Importance that = (Importance) o;
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
Double.compare(that.min, min) == 0 &&
Double.compare(that.max, max) == 0;
}
@Override
public int hashCode() {
return Objects.hash(meanMagnitude, min, max);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
builder.field(MIN.getPreferredName(), min);
builder.field(MAX.getPreferredName(), max);
builder.endObject();
return builder;
}
}
public static class ClassImportance implements ToXContentObject {
private static final String NAME = "total_class_importance";
public static final ParseField CLASS_NAME = new ParseField("class_name");
public static final ParseField IMPORTANCE = new ParseField("importance");
public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new ClassImportance(a[0], (Importance)a[1]));
static {
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
return p.text();
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
return p.numberValue();
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
return p.booleanValue();
}
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
}
public static ClassImportance fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
public final Object className;
public final Importance importance;
ClassImportance(Object className, Importance importance) {
this.className = className;
this.importance = importance;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(IMPORTANCE.getPreferredName(), importance);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassImportance that = (ClassImportance) o;
return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
}
@Override
public int hashCode() {
return Objects.hash(className, importance);
}
}
}

View File

@ -894,7 +894,7 @@ public class MLRequestConvertersTests extends ESTestCase {
GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
.setAllowNoMatch(false)
.setDecompressDefinition(true)
.setIncludeDefinition(false)
.includeDefinition()
.setTags("tag1", "tag2")
.setPageParams(new PageParams(100, 300));
@ -908,7 +908,7 @@ public class MLRequestConvertersTests extends ESTestCase {
hasEntry("allow_no_match", "false"),
hasEntry("decompress_definition", "true"),
hasEntry("tags", "tag1,tag2"),
hasEntry("include_model_definition", "false")
hasEntry("include", "definition")
));
assertNull(request.getEntity());
}

View File

@ -2257,7 +2257,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
{
GetTrainedModelsResponse getTrainedModelsResponse = execute(
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true),
new GetTrainedModelsRequest(modelIdPrefix + 0)
.setDecompressDefinition(true)
.includeDefinition()
.includeTotalFeatureImportance(),
machineLearningClient::getTrainedModels,
machineLearningClient::getTrainedModelsAsync);
@ -2268,7 +2271,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
getTrainedModelsResponse = execute(
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true),
new GetTrainedModelsRequest(modelIdPrefix + 0)
.setDecompressDefinition(false)
.includeTotalFeatureImportance()
.includeDefinition(),
machineLearningClient::getTrainedModels,
machineLearningClient::getTrainedModelsAsync);
@ -2279,7 +2285,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
getTrainedModelsResponse = execute(
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false),
new GetTrainedModelsRequest(modelIdPrefix + 0)
.setDecompressDefinition(false),
machineLearningClient::getTrainedModels,
machineLearningClient::getTrainedModelsAsync);
assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));

View File

@ -3694,11 +3694,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
// tag::get-trained-models-request
GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
.setPageParams(new PageParams(0, 1)) // <2>
.setIncludeDefinition(false) // <3>
.setDecompressDefinition(false) // <4>
.setAllowNoMatch(true) // <5>
.setTags("regression") // <6>
.setForExport(false); // <7>
.includeDefinition() // <3>
.includeTotalFeatureImportance() // <4>
.setDecompressDefinition(false) // <5>
.setAllowNoMatch(true) // <6>
.setTags("regression") // <7>
.setForExport(false); // <8>
// end::get-trained-models-request
request.setTags((List<String>)null);

View File

@ -0,0 +1,63 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
public static TotalFeatureImportance randomInstance() {
return new TotalFeatureImportance(
randomAlphaOfLength(10),
randomBoolean() ? null : randomImportance(),
randomBoolean() ?
null :
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
.limit(randomIntBetween(1, 10))
.collect(Collectors.toList())
);
}
private static TotalFeatureImportance.Importance randomImportance() {
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
}
@Override
protected TotalFeatureImportance createTestInstance() {
return randomInstance();
}
@Override
protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
return TotalFeatureImportance.fromXContent(parser);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -22,26 +22,28 @@ IDs, or the special wildcard `_all` to get all trained models.
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-request]
--------------------------------------------------
<1> Constructing a new GET request referencing an existing Trained Model
<1> Constructing a new GET request referencing an existing trained model
<2> Set the paging parameters
<3> Indicate if the complete model definition should be included
<4> Should the definition be fully decompressed on GET
<5> Allow empty response if no Trained Models match the provided ID patterns.
If false, an error will be thrown if no Trained Models match the
<4> Indicate if the total feature importance for the features used in training
should be included in the model `metadata` field.
<5> Should the definition be fully decompressed on GET
<6> Allow empty response if no trained models match the provided ID patterns.
If false, an error will be thrown if no trained models match the
ID patterns.
<6> An optional list of tags used to narrow the model search. A Trained Model
<7> An optional list of tags used to narrow the model search. A trained model
can have many tags or none. The trained models in the response will
contain all the provided tags.
<7> Optional boolean value indicating if certain fields should be removed on
retrieval. This is useful for getting the trained model in a format that
can then be put into another cluster.
<8> Optional boolean value for requesting the trained model in a format that can
then be put into another cluster. Certain fields that can only be set when
the model is imported are removed.
include::../execution.asciidoc[]
[id="{upid}-{api}-response"]
==== Response
The returned +{response}+ contains the requested Trained Model.
The returned +{response}+ contains the requested trained model.
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------

View File

@ -29,18 +29,19 @@ experimental[]
[[ml-get-inference-prereq]]
== {api-prereq-title}
Required privileges which should be added to a custom role:
If the {es} {security-features} are enabled, you must have the following
privileges:
* cluster: `monitor_ml`
For more information, see <<security-privileges>> and
For more information, see <<security-privileges>> and
{ml-docs-setup-privileges}.
[[ml-get-inference-desc]]
== {api-description-title}
You can get information for multiple trained models in a single API request by
You can get information for multiple trained models in a single API request by
using a comma-separated list of model IDs or a wildcard expression.
@ -48,7 +49,7 @@ using a comma-separated list of model IDs or a wildcard expression.
== {api-path-parms-title}
`<model_id>`::
(Optional, string)
(Optional, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
@ -56,12 +57,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
== {api-query-parms-title}
`allow_no_match`::
(Optional, boolean)
(Optional, boolean)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models]
`decompress_definition`::
(Optional, boolean)
Specifies whether the included model definition should be returned as a JSON map
Specifies whether the included model definition should be returned as a JSON map
(`true`) or in a custom compressed format (`false`). Defaults to `true`.
`for_export`::
@ -71,17 +72,21 @@ retrieval. This allows the model to be in an acceptable format to be retrieved
and then added to another cluster. Default is false.
`from`::
(Optional, integer)
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models]
`include_model_definition`::
(Optional, boolean)
Specifies whether the model definition is returned in the response. Defaults to
`false`. When `true`, only a single model must match the ID patterns provided.
Otherwise, a bad request is returned.
`include`::
(Optional, string)
A comma-delimited string of optional fields to include in the response body.
Valid options are:
- `definition`: Includes the model definition
- `total_feature_importance`: Includes the total feature importance for the
training data set. This field is available in the `metadata` field in the
response body.
The default is an empty string, meaning no optional fields are included.
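
For illustration, a minimal sketch of the same request through the low-level Java REST client, assuming the 7.x `_ml/inference/<model_id>` endpoint; the model id is a placeholder:

```java
import java.io.IOException;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class IncludeParameterExample {

    // Fetch one model and ask for both optional fields via the new include parameter.
    static Response getWithIncludes(RestClient client, String modelId) throws IOException {
        Request request = new Request("GET", "/_ml/inference/" + modelId);
        request.addParameter("include", "definition,total_feature_importance");
        request.addParameter("decompress_definition", "true");
        return client.performRequest(request);
    }
}
```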
`size`::
(Optional, integer)
(Optional, integer)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models]
`tags`::
@ -94,7 +99,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tags]
`trained_model_configs`::
(array)
An array of trained model resources, which are sorted by the `model_id` value in
An array of trained model resources, which are sorted by the `model_id` value in
ascending order.
+
.Properties of trained model resources
@ -132,8 +137,86 @@ The license level of the trained model.
`metadata`:::
(object)
An object containing metadata about the trained model. For example, models
An object containing metadata about the trained model. For example, models
created by {dfanalytics} contain `analysis_config` and `input` objects.
.Properties of metadata
[%collapsible%open]
=====
`total_feature_importance`:::
(array)
An array of the total feature importance for each feature used from
the training data set. This array of objects is returned if {dfanalytics} trained
the model and the request includes `total_feature_importance` in the `include`
request parameter.
+
.Properties of total feature importance
[%collapsible%open]
======
`feature_name`:::
(string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name]
`importance`:::
(object)
A collection of feature importance statistics related to the training data set for this particular feature.
+
.Properties of feature importance
[%collapsible%open]
=======
`mean_magnitude`:::
(double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
`max`:::
(int)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
`min`:::
(int)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
=======
`classes`:::
(array)
If the trained model is a classification model, feature importance statistics are gathered
per target class value.
+
.Properties of class feature importance
[%collapsible%open]
=======
`class_name`:::
(string)
The target class value. Could be a string, boolean, or number.
`importance`:::
(object)
A collection of feature importance statistics related to the training data set for this particular feature.
+
.Properties of feature importance
[%collapsible%open]
========
`mean_magnitude`:::
(double)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
`max`:::
(int)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
`min`:::
(int)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
========
=======
======
=====
`model_id`:::
(string)
@ -152,13 +235,13 @@ The {es} version number in which the trained model was created.
== {api-response-codes-title}
`400`::
If `include_model_definition` is `true`, this code indicates that more than
If `include_model_definition` is `true`, this code indicates that more than
one model matches the ID pattern.
`404` (Missing resources)::
If `allow_no_match` is `false`, this code indicates that there are no
resources that match the request or only partial matches for the request.
[[ml-get-inference-example]]
== {api-examples-title}

View File

@ -780,6 +780,23 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that
used to train the model, which defaults to `<dependent_variable>_prediction`.
end::inference-config-results-field-processor[]
tag::inference-metadata-feature-importance-feature-name[]
The training feature name for which this importance was calculated.
end::inference-metadata-feature-importance-feature-name[]
tag::inference-metadata-feature-importance-magnitude[]
The average magnitude of this feature across all the training data.
This value is the average of the absolute values of the importance
for this feature.
end::inference-metadata-feature-importance-magnitude[]
tag::inference-metadata-feature-importance-max[]
The maximum importance value across all the training data for this
feature.
end::inference-metadata-feature-importance-max[]
tag::inference-metadata-feature-importance-min[]
The minimum importance value across all the training data for this
feature.
end::inference-metadata-feature-importance-min[]
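
To make these three statistics concrete, a small illustrative sketch of how they relate to per-document importance values; the sample values are hypothetical and not taken from any real training data:

```java
public class FeatureImportanceStatsExample {

    public static void main(String[] args) {
        // Hypothetical per-document importance values for one feature across the training data.
        double[] importances = { 0.4, -1.2, 0.7, -0.1 };

        double sumOfMagnitudes = 0.0;
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double importance : importances) {
            sumOfMagnitudes += Math.abs(importance); // mean_magnitude averages absolute values
            min = Math.min(min, importance);         // min and max keep the sign
            max = Math.max(max, importance);
        }
        double meanMagnitude = sumOfMagnitudes / importances.length;

        // Prints mean_magnitude=0.6 min=-1.2 max=0.7 for the sample values above.
        System.out.printf("mean_magnitude=%.1f min=%.1f max=%.1f%n", meanMagnitude, min, max);
    }
}
```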
tag::influencers[]
A comma separated list of influencer field names. Typically these can be the by,
over, or partition fields that are used in the detector configuration. You might

View File

@ -10,15 +10,19 @@ import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
@ -32,23 +36,60 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
public static class Request extends AbstractGetResourcesRequest {
public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
static final String DEFINITION = "definition";
static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
private static final Set<String> KNOWN_INCLUDES;
static {
HashSet<String> includes = new HashSet<>(2, 1.0f);
includes.add(DEFINITION);
includes.add(TOTAL_FEATURE_IMPORTANCE);
KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
}
public static final ParseField INCLUDE = new ParseField("include");
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");
private final boolean includeModelDefinition;
private final Set<String> includes;
private final List<String> tags;
@Deprecated
public Request(String id, boolean includeModelDefinition, List<String> tags) {
setResourceId(id);
setAllowNoResources(true);
this.includeModelDefinition = includeModelDefinition;
this.tags = tags == null ? Collections.emptyList() : tags;
if (includeModelDefinition) {
this.includes = new HashSet<>(Collections.singletonList(DEFINITION));
} else {
this.includes = Collections.emptySet();
}
}
public Request(String id, List<String> tags, Set<String> includes) {
setResourceId(id);
setAllowNoResources(true);
this.tags = tags == null ? Collections.emptyList() : tags;
this.includes = includes == null ? Collections.emptySet() : includes;
Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
if (unknownIncludes.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"unknown [include] parameters {}. Valid options are {}",
unknownIncludes,
KNOWN_INCLUDES);
}
}
public Request(StreamInput in) throws IOException {
super(in);
this.includeModelDefinition = in.readBoolean();
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
this.includes = in.readSet(StreamInput::readString);
} else {
Set<String> includes = new HashSet<>();
if (in.readBoolean()) {
includes.add(DEFINITION);
}
this.includes = includes;
}
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.tags = in.readStringList();
} else {
@ -62,7 +103,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
}
public boolean isIncludeModelDefinition() {
return includeModelDefinition;
return this.includes.contains(DEFINITION);
}
public boolean isIncludeTotalFeatureImportance() {
return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
}
public List<String> getTags() {
@ -72,7 +117,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(includeModelDefinition);
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
out.writeCollection(this.includes, StreamOutput::writeString);
} else {
out.writeBoolean(this.includes.contains(DEFINITION));
}
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeStringCollection(tags);
}
@ -80,7 +129,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), includeModelDefinition, tags);
return Objects.hash(super.hashCode(), includes, tags);
}
@Override
@ -92,7 +141,18 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
return false;
}
Request other = (Request) obj;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
return super.equals(obj) && Objects.equals(includes, other.includes) && Objects.equals(tags, other.tags);
}
@Override
public String toString() {
return "Request{" +
"includes=" + includes +
", tags=" + tags +
", page=" + getPageParams() +
", id=" + getResourceId() +
", allow_missing=" + isAllowNoResources() +
'}';
}
}

View File

@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
@ -39,6 +40,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static org.elasticsearch.action.ValidateActions.addValidationError;
@ -51,6 +53,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String FOR_EXPORT = "for_export";
public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
private static final Set<String> RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE);
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
@ -419,7 +423,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
this.description = config.getDescription();
this.tags = config.getTags();
this.metadata = config.getMetadata();
this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata());
this.input = config.getInput();
this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory;
@ -471,6 +475,18 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
public Builder setFeatureImportance(List<TotalFeatureImportance> totalFeatureImportance) {
if (totalFeatureImportance == null) {
return this;
}
if (this.metadata == null) {
this.metadata = new HashMap<>();
}
this.metadata.put(TOTAL_FEATURE_IMPORTANCE,
totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()));
return this;
}
public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
if (definition == null) {
return this;
@ -627,6 +643,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
ESTIMATED_OPERATIONS.getPreferredName(),
validationException);
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
if (metadata != null) {
validationException = checkIllegalSetting(
metadata.get(TOTAL_FEATURE_IMPORTANCE),
METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE,
validationException);
}
}
if (validationException != null) {

View File

@ -20,8 +20,11 @@ import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
public class TotalFeatureImportance implements ToXContentObject, Writeable {
@ -81,16 +84,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(FEATURE_NAME.getPreferredName(), featureName);
if (importance != null) {
builder.field(IMPORTANCE.getPreferredName(), importance);
}
if (classImportances.isEmpty() == false) {
builder.field(CLASSES.getPreferredName(), classImportances);
}
builder.endObject();
return builder;
return builder.map(asMap());
}
@Override
@ -103,6 +97,18 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
&& Objects.equals(classImportances, that.classImportances);
}
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(FEATURE_NAME.getPreferredName(), featureName);
if (importance != null) {
map.put(IMPORTANCE.getPreferredName(), importance.asMap());
}
if (classImportances.isEmpty() == false) {
map.put(CLASSES.getPreferredName(), classImportances.stream().map(ClassImportance::asMap).collect(Collectors.toList()));
}
return map;
}
@Override
public int hashCode() {
return Objects.hash(featureName, importance, classImportances);
@ -165,12 +171,15 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
builder.field(MIN.getPreferredName(), min);
builder.field(MAX.getPreferredName(), max);
builder.endObject();
return builder;
return builder.map(asMap());
}
private Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
map.put(MIN.getPreferredName(), min);
map.put(MAX.getPreferredName(), max);
return map;
}
}
@ -229,11 +238,14 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASS_NAME.getPreferredName(), className);
builder.field(IMPORTANCE.getPreferredName(), importance);
builder.endObject();
return builder;
return builder.map(asMap());
}
private Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(CLASS_NAME.getPreferredName(), className);
map.put(IMPORTANCE.getPreferredName(), importance.asMap());
return map;
}
@Override

View File

@ -53,6 +53,10 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
return NAME + "-" + modelId;
}
public static String modelId(String docId) {
return docId.substring(NAME.length() + 1);
}
private final List<TotalFeatureImportance> totalFeatureImportances;
private final String modelId;

View File

@ -103,7 +103,7 @@ public final class Messages {
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]";
public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}";
public static final String INFERENCE_CANNOT_DELETE_MODEL =
"Unable to delete model [{0}]";
public static final String MODEL_DEFINITION_TRUNCATED =

View File

@ -5,19 +5,28 @@
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(20),
randomBoolean(),
randomBoolean() ? null :
randomList(10, () -> randomAlphaOfLength(10)));
randomList(10, () -> randomAlphaOfLength(10)),
randomBoolean() ? null :
Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE))
.limit(4)
.collect(Collectors.toSet()));
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
return request;
}
@ -26,4 +35,22 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}
@Override
protected Request mutateInstanceForVersion(Request instance, Version version) {
if (version.before(Version.V_7_10_0)) {
Set<String> includes = new HashSet<>();
if (instance.isIncludeModelDefinition()) {
includes.add(Request.DEFINITION);
}
Request request = new Request(
instance.getResourceId(),
version.before(Version.V_7_7_0) ? null : instance.getTags(),
includes);
request.setPageParams(instance.getPageParams());
request.setAllowNoResources(instance.isAllowNoResources());
return request;
}
return instance;
}
}

View File

@ -42,11 +42,13 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.startsWith;
public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
@ -95,19 +97,21 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
assertThat(ids.v1(), equalTo(1L));
String inferenceModelId = ids.v2().iterator().next();
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);
trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture);
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
PlainActionFuture<TrainedModelMetadata> getTrainedMetadataFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture);
PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet();
TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
assertThat(storedMetadata.getModelId(), startsWith(modelId));
assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances()));
}

View File

@ -90,7 +90,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(config));
@ -121,7 +124,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
assertThat(exceptionHolder.get(), is(nullValue()));
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder);
blockingCall(listener ->
trainedModelProvider.getTrainedModel(modelId, false, false, listener),
getConfigHolder,
exceptionHolder);
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition));
@ -132,7 +138,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
String modelId = "test-get-missing-trained-model-config";
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
@ -154,7 +163,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
.actionGet();
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(),
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
@ -193,7 +205,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
}
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(getConfigHolder.get(), is(nullValue()));
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
@ -238,7 +253,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
}
}
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
getConfigHolder,
exceptionHolder);
assertThat(getConfigHolder.get(), is(nullValue()));
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));

View File

@ -57,15 +57,25 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
}
if (request.isIncludeModelDefinition()) {
provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap(
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
listener::onFailure
));
provider.getTrainedModel(
totalAndIds.v2().iterator().next(),
true,
request.isIncludeTotalFeatureImportance(),
ActionListener.wrap(
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
listener::onFailure
)
);
} else {
provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap(
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
listener::onFailure
));
provider.getTrainedModels(
totalAndIds.v2(),
request.isAllowNoResources(),
request.isIncludeTotalFeatureImportance(),
ActionListener.wrap(
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
listener::onFailure
)
);
}
},
listener::onFailure

View File

@ -82,7 +82,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
responseBuilder.setLicensed(true);
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
} else {
trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap(
trainedModelConfig -> {
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {

View File

@ -270,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener {
}
private void loadModel(String modelId, Consumer consumer) {
provider.getTrainedModel(modelId, false, ActionListener.wrap(
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
@ -306,7 +306,7 @@ public class ModelLoadingService implements ClusterStateListener {
// If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, false, ActionListener.wrap(
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
@ -434,7 +434,7 @@ public class ModelLoadingService implements ClusterStateListener {
logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
notification.getValue().model.getModelId()));
// If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
} finally {

View File

@ -89,9 +89,11 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
@ -235,14 +237,14 @@ public class TrainedModelProvider {
));
}
public void getTrainedModelMetadata(String modelId, ActionListener<TrainedModelMetadata> listener) {
public void getTrainedModelMetadata(Collection<String> modelIds, ActionListener<Map<String, TrainedModelMetadata>> listener) {
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
.boolQuery()
.filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
.filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
TrainedModelMetadata.NAME))))
.setSize(1)
.setSize(10_000)
// First find the latest index
.addSort("_index", SortOrder.DESC)
.request();
@ -250,18 +252,20 @@ public class TrainedModelProvider {
searchResponse -> {
if (searchResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
return;
}
List<TrainedModelMetadata> metadataList = handleHits(searchResponse.getHits().getHits(),
modelId,
this::parseMetadataLenientlyFromSource);
listener.onResponse(metadataList.get(0));
HashMap<String, TrainedModelMetadata> map = new HashMap<>();
for (SearchHit hit : searchResponse.getHits().getHits()) {
String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId()));
map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId));
}
listener.onResponse(map);
},
e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
return;
}
listener.onFailure(e);
@ -371,7 +375,7 @@ public class TrainedModelProvider {
// TODO Change this when we get more than just langIdent stored
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
listener.onResponse(
InferenceDefinition.builder()
@ -434,18 +438,50 @@ public class TrainedModelProvider {
));
}
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
public void getTrainedModel(final String modelId,
final boolean includeDefinition,
final boolean includeTotalFeatureImportance,
final ActionListener<TrainedModelConfig> finalListener) {
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
return;
} catch (ElasticsearchException ex) {
listener.onFailure(ex);
finalListener.onFailure(ex);
return;
}
}
ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
modelBuilder -> {
if (includeTotalFeatureImportance == false) {
finalListener.onResponse(modelBuilder.build());
return;
}
this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap(
metadata -> {
TrainedModelMetadata modelMetadata = metadata.get(modelId);
if (modelMetadata != null) {
modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
finalListener.onResponse(modelBuilder.build());
},
failure -> {
// total feature importance is not necessary for a model to be valid
// we shouldn't fail if it is not found
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
finalListener.onResponse(modelBuilder.build());
return;
}
finalListener.onFailure(failure);
}
));
},
finalListener::onFailure
);
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
.idsQuery()
.addIds(modelId));
@ -483,11 +519,11 @@ public class TrainedModelProvider {
try {
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
} catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException(
getTrainedModelListener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return;
} catch (Exception ex) {
listener.onFailure(ex);
getTrainedModelListener.onFailure(ex);
return;
}
@ -500,22 +536,22 @@ public class TrainedModelProvider {
String compressedString = getDefinitionFromDocs(docs, modelId);
builder.setDefinitionFromString(compressedString);
} catch (ElasticsearchException elasticsearchException) {
listener.onFailure(elasticsearchException);
getTrainedModelListener.onFailure(elasticsearchException);
return;
}
} catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException(
getTrainedModelListener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
} catch (Exception ex) {
listener.onFailure(ex);
getTrainedModelListener.onFailure(ex);
return;
}
}
listener.onResponse(builder.build());
getTrainedModelListener.onResponse(builder);
},
listener::onFailure
getTrainedModelListener::onFailure
);
executeAsyncWithOrigin(client,
@ -532,7 +568,10 @@ public class TrainedModelProvider {
* This does no expansion on the ids.
* It assumes that there are fewer than 10k.
*/
public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) {
public void getTrainedModels(Set<String> modelIds,
boolean allowNoResources,
boolean includeTotalFeatureImportance,
final ActionListener<List<TrainedModelConfig>> finalListener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
@ -541,23 +580,63 @@ public class TrainedModelProvider {
.setQuery(queryBuilder)
.setSize(modelIds.size())
.request();
List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
List<TrainedModelConfig.Builder> configs = new ArrayList<>(modelIds.size());
Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
for(String modelId : modelsAsResource) {
try {
configs.add(loadModelFromResource(modelId, true));
} catch (ElasticsearchException ex) {
listener.onFailure(ex);
finalListener.onFailure(ex);
return;
}
}
if (modelsInIndex.isEmpty()) {
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
listener.onResponse(configs);
finalListener.onResponse(configs.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}
ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
modelBuilders -> {
if (includeTotalFeatureImportance == false) {
finalListener.onResponse(modelBuilders.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}
this.getTrainedModelMetadata(modelIds, ActionListener.wrap(
metadata ->
finalListener.onResponse(modelBuilders.stream()
.map(builder -> {
TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
if (modelMetadata != null) {
builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
return builder.build();
})
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList())),
failure -> {
// total feature importance is not necessary for a model to be valid
// we shouldn't fail if it is not found
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
finalListener.onResponse(modelBuilders.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}
finalListener.onFailure(failure);
}
));
},
finalListener::onFailure
);
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
searchResponse -> {
Set<String> observedIds = new HashSet<>(
@ -568,12 +647,12 @@ public class TrainedModelProvider {
try {
if (observedIds.contains(searchHit.getId()) == false) {
configs.add(
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build()
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId())
);
observedIds.add(searchHit.getId());
}
} catch (IOException ex) {
listener.onFailure(
getTrainedModelListener.onFailure(
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
return;
}
@ -583,14 +662,13 @@ public class TrainedModelProvider {
// Otherwise, treat it as if it was never expanded to begin with.
Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
if (missingConfigs.isEmpty() == false && allowNoResources == false) {
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
return;
}
// Ensure sorted even with the injection of locally resourced models
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
listener.onResponse(configs);
getTrainedModelListener.onResponse(configs);
},
listener::onFailure
getTrainedModelListener::onFailure
);
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
@ -639,7 +717,7 @@ public class TrainedModelProvider {
foundResourceIds = new HashSet<>();
for(String resourceId : matchedResourceIds) {
// Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) {
foundResourceIds.add(resourceId);
}
}
@ -833,7 +911,7 @@ public class TrainedModelProvider {
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
}
TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) {
URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
if (resource == null) {
logger.error("[{}] presumed stored as a resource but not found", modelId);
@ -848,7 +926,7 @@ public class TrainedModelProvider {
if (nullOutDefinition) {
builder.clearDefinition();
}
return builder.build();
return builder;
} catch (IOException ioEx) {
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
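The provider change above follows a build-then-enrich pattern: configs are parsed into builders first, and the optional total feature importance metadata is merged into each builder before build() is called, with a missing metadata document treated as non-fatal. A minimal, self-contained sketch of that pattern follows; it uses simplified stand-in types, not the actual Elasticsearch classes.

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Simplified stand-ins for TrainedModelConfig, its Builder, and TrainedModelMetadata.
record Config(String modelId, List<String> totalFeatureImportance) {}
record Metadata(List<String> totalFeatureImportance) {}

class ConfigBuilder {
    private final String modelId;
    private List<String> featureImportance = List.of();

    ConfigBuilder(String modelId) { this.modelId = modelId; }
    String modelId() { return modelId; }
    ConfigBuilder setFeatureImportance(List<String> fi) { this.featureImportance = fi; return this; }
    Config build() { return new Config(modelId, featureImportance); }
}

class MetadataEnrichmentSketch {
    // Merge optional per-model metadata into the builders, then build and sort by model id,
    // mirroring how the provider only calls build() after the metadata lookup completes.
    static List<Config> enrich(List<ConfigBuilder> builders, Map<String, Metadata> metadataById) {
        return builders.stream()
            .map(builder -> {
                Metadata metadata = metadataById.get(builder.modelId());
                if (metadata != null) { // a missing metadata document is not an error
                    builder.setFeatureImportance(metadata.totalFeatureImportance());
                }
                return builder.build();
            })
            .sorted(Comparator.comparing(Config::modelId))
            .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<ConfigBuilder> builders = List.of(new ConfigBuilder("b-model"), new ConfigBuilder("a-model"));
        Map<String, Metadata> metadata = Map.of("a-model", new Metadata(List.of("foo", "bar")));
        System.out.println(enrich(builders, metadata)); // only "a-model" gains importance values
    }
}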

View File

@ -25,6 +25,7 @@ import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -56,12 +57,17 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
if (Strings.isNullOrEmpty(modelId)) {
modelId = Metadata.ALL;
}
boolean includeModelDefinition = restRequest.paramAsBoolean(
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
false
);
List<String> tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
Set<String> includes = new HashSet<>(
asList(
restRequest.paramAsStringArray(
GetTrainedModelsAction.Request.INCLUDE.getPreferredName(),
Strings.EMPTY_ARRAY)));
final GetTrainedModelsAction.Request request = restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION) ?
new GetTrainedModelsAction.Request(modelId,
restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false),
tags) :
new GetTrainedModelsAction.Request(modelId, tags, includes);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
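To keep the deprecated boolean flag working, the handler above builds the request from include_model_definition only when that parameter is explicitly present, and otherwise from the new include set. A rough sketch of that precedence, with a hypothetical simplified request type standing in for GetTrainedModelsAction.Request:

import java.util.Map;
import java.util.Set;

class IncludePrecedenceSketch {

    // Hypothetical simplified request: either the legacy boolean or the new include set is used.
    record ModelsRequest(String modelId, boolean includeDefinition, Set<String> includes) {}

    static ModelsRequest buildRequest(String modelId, Map<String, String> params, Set<String> includes) {
        // The deprecated flag wins if the caller sent it explicitly ...
        if (params.containsKey("include_model_definition")) {
            return new ModelsRequest(modelId, Boolean.parseBoolean(params.get("include_model_definition")), Set.of());
        }
        // ... otherwise the comma-separated include values drive what is returned.
        return new ModelsRequest(modelId, false, includes);
    }

    public static void main(String[] args) {
        System.out.println(buildRequest("a-model", Map.of("include_model_definition", "true"), Set.of()));
        System.out.println(buildRequest("a-model", Map.of(), Set.of("definition", "total_feature_importance")));
    }
}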

View File

@ -437,9 +437,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
// the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three
assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any());
});
assertBusy(() -> {
assertThat(circuitBreaker.getUsed(), equalTo(10L));
@ -553,10 +553,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
}
@SuppressWarnings("unchecked")
@ -564,20 +564,20 @@ public class ModelLoadingServiceTests extends ESTestCase {
if (randomBoolean()) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
} else {
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];

View File

@ -57,14 +57,14 @@ public class TrainedModelProviderTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, true, future);
trainedModelProvider.getTrainedModel(modelId, true, false, future);
TrainedModelConfig configWithDefinition = future.actionGet();
assertThat(configWithDefinition.getModelId(), equalTo(modelId));
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition);
trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition);
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));

View File

@ -33,7 +33,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future);
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future);
TrainedModelConfig config = future.actionGet();
config.ensureParsedDefinition(xContentRegistry());

View File

@ -34,11 +34,10 @@
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
"default":true
},
"include_model_definition":{
"type":"boolean",
"include":{
"type":"string",
"required":false,
"description":"Should the full model definition be included in the results. These definitions can be large. So be cautious when including them. Defaults to false.",
"default":false
"description":"A comma-separate list of fields to optionally include. Valid options are 'definition' and 'total_feature_importance'. Default is none."
},
"decompress_definition":{
"type":"boolean",

View File

@ -1,6 +1,24 @@
setup:
- skip:
features: headers
features:
- headers
- allowed_warnings
- do:
allowed_warnings:
- "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template"
headers:
Content-Type: application/json
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_metadata-a-regression-model-0
index: .ml-inference-000003
body:
model_id: "a-regression-model-0"
doc_type: "trained_model_metadata"
total_feature_importance:
- { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
- { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
@ -548,6 +566,20 @@ setup:
- match: { count: 12 }
- match: { trained_model_configs.0.model_id: "a-regression-model-1" }
---
"Test get models with include total feature importance":
- do:
ml.get_trained_models:
model_id: "a-regression-model-*"
include: "total_feature_importance"
- match: { count: 2 }
- length: { trained_model_configs: 2 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
- is_true: trained_model_configs.0.metadata.total_feature_importance
- length: { trained_model_configs.0.metadata.total_feature_importance: 2 }
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
- is_false: trained_model_configs.1.metadata.total_feature_importance
---
"Test delete given unused trained model":
- do:
ml.delete_trained_model:
@ -824,7 +856,7 @@ setup:
ml.get_trained_models:
model_id: "a-regression-model-1"
for_export: true
include_model_definition: true
include: "definition"
decompress_definition: false
- match: { trained_model_configs.0.description: "empty model for tests" }