[7.x] [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922) (#62620)
* [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922)

  Adds a new `include` flag to the get trained models API. The flag initially has two
  valid values: `definition` and `total_feature_importance`. Consequently, the old
  `include_model_definition` flag is now deprecated. When `total_feature_importance`
  is requested, the `total_feature_importance` field is included in the model
  `metadata` object. Including `definition` is the same as previously setting
  `include_model_definition=true`.

* fixing test

* Update x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
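For reviewers, a minimal sketch of what the change means for a high-level REST client caller. The builder methods come from this commit's GetTrainedModelsRequest; the wrapper class and the client package import are illustrative assumptions.

import org.elasticsearch.client.ml.GetTrainedModelsRequest;

class IncludeFlagMigrationSketch {
    // Old style: a boolean flag that could only toggle the model definition.
    static GetTrainedModelsRequest before() {
        return new GetTrainedModelsRequest("my-trained-model")
            .setIncludeDefinition(true); // now deprecated; delegates to includeDefinition()
    }

    // New style: opt in to optional fields one by one; this serializes as
    // ?include=definition,total_feature_importance on the GET request.
    static GetTrainedModelsRequest after() {
        return new GetTrainedModelsRequest("my-trained-model")
            .includeDefinition()
            .includeTotalFeatureImportance();
    }
}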
This commit is contained in: parent e1a4a3073a, commit e163559e4c
@@ -779,9 +779,9 @@ final class MLRequestConverters {
             params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
                 Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
         }
-        if (getTrainedModelsRequest.getIncludeDefinition() != null) {
-            params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
-                Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
+        if (getTrainedModelsRequest.getIncludes().isEmpty() == false) {
+            params.putParam(GetTrainedModelsRequest.INCLUDE,
+                Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes()));
         }
         if (getTrainedModelsRequest.getTags() != null) {
             params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
@@ -26,21 +26,26 @@ import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.common.Nullable;
 
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 
 public class GetTrainedModelsRequest implements Validatable {
 
+    private static final String DEFINITION = "definition";
+    private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
     public static final String ALLOW_NO_MATCH = "allow_no_match";
     public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
     public static final String FOR_EXPORT = "for_export";
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
     public static final String TAGS = "tags";
+    public static final String INCLUDE = "include";
 
     private final List<String> ids;
     private Boolean allowNoMatch;
-    private Boolean includeDefinition;
+    private Set<String> includes = new HashSet<>();
     private Boolean decompressDefinition;
     private Boolean forExport;
     private PageParams pageParams;

@@ -86,19 +91,32 @@ public class GetTrainedModelsRequest implements Validatable {
         return this;
     }
 
-    public Boolean getIncludeDefinition() {
-        return includeDefinition;
+    public Set<String> getIncludes() {
+        return Collections.unmodifiableSet(includes);
     }
 
+    public GetTrainedModelsRequest includeDefinition() {
+        this.includes.add(DEFINITION);
+        return this;
+    }
+
+    public GetTrainedModelsRequest includeTotalFeatureImportance() {
+        this.includes.add(TOTAL_FEATURE_IMPORTANCE);
+        return this;
+    }
+
     /**
      * Whether to include the full model definition.
      *
      * The full model definition can be very large.
      *
+     * @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()}
      * @param includeDefinition If {@code true}, the definition is included.
      */
+    @Deprecated
     public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
-        this.includeDefinition = includeDefinition;
+        if (includeDefinition != null && includeDefinition) {
+            return this.includeDefinition();
+        }
         return this;
     }
 

@@ -173,13 +191,13 @@ public class GetTrainedModelsRequest implements Validatable {
         return Objects.equals(ids, other.ids)
             && Objects.equals(allowNoMatch, other.allowNoMatch)
             && Objects.equals(decompressDefinition, other.decompressDefinition)
-            && Objects.equals(includeDefinition, other.includeDefinition)
+            && Objects.equals(includes, other.includes)
             && Objects.equals(forExport, other.forExport)
             && Objects.equals(pageParams, other.pageParams);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
+        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport);
     }
 }
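Because the deprecated setter now delegates into the include set, old and new call sites build equivalent requests. A small illustrative check (not part of the commit):

GetTrainedModelsRequest legacy = new GetTrainedModelsRequest("my-trained-model")
    .setIncludeDefinition(true);  // deprecated path
GetTrainedModelsRequest current = new GetTrainedModelsRequest("my-trained-model")
    .includeDefinition();         // new path

// Both carry include=definition. Note that setIncludeDefinition(false) is now
// a no-op rather than an explicit exclusion, since includes only accumulate.
assert legacy.getIncludes().equals(current.getIncludes());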
@@ -0,0 +1,208 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.client.ml.inference.trainedmodel.metadata;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

public class TotalFeatureImportance implements ToXContentObject {

    private static final String NAME = "total_feature_importance";
    public static final ParseField FEATURE_NAME = new ParseField("feature_name");
    public static final ParseField IMPORTANCE = new ParseField("importance");
    public static final ParseField CLASSES = new ParseField("classes");
    public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
    public static final ParseField MIN = new ParseField("min");
    public static final ParseField MAX = new ParseField("max");

    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
        true,
        a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
        PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
    }

    public static TotalFeatureImportance fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    public final String featureName;
    public final Importance importance;
    public final List<ClassImportance> classImportances;

    TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
        this.featureName = featureName;
        this.importance = importance;
        this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FEATURE_NAME.getPreferredName(), featureName);
        if (importance != null) {
            builder.field(IMPORTANCE.getPreferredName(), importance);
        }
        if (classImportances.isEmpty() == false) {
            builder.field(CLASSES.getPreferredName(), classImportances);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TotalFeatureImportance that = (TotalFeatureImportance) o;
        return Objects.equals(that.importance, importance)
            && Objects.equals(featureName, that.featureName)
            && Objects.equals(classImportances, that.classImportances);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureName, importance, classImportances);
    }

    public static class Importance implements ToXContentObject {
        private static final String NAME = "importance";

        public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
            true,
            a -> new Importance((double)a[0], (double)a[1], (double)a[2]));

        static {
            PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
            PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
            PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
        }

        private final double meanMagnitude;
        private final double min;
        private final double max;

        public Importance(double meanMagnitude, double min, double max) {
            this.meanMagnitude = meanMagnitude;
            this.min = min;
            this.max = max;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Importance that = (Importance) o;
            return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
                Double.compare(that.min, min) == 0 &&
                Double.compare(that.max, max) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(meanMagnitude, min, max);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject();
            builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
            builder.field(MIN.getPreferredName(), min);
            builder.field(MAX.getPreferredName(), max);
            builder.endObject();
            return builder;
        }
    }

    public static class ClassImportance implements ToXContentObject {
        private static final String NAME = "total_class_importance";

        public static final ParseField CLASS_NAME = new ParseField("class_name");
        public static final ParseField IMPORTANCE = new ParseField("importance");

        public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
            true,
            a -> new ClassImportance(a[0], (Importance)a[1]));

        static {
            PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
                if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
                    return p.text();
                } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
                    return p.numberValue();
                } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
                    return p.booleanValue();
                }
                throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
            }, CLASS_NAME, ObjectParser.ValueType.VALUE);
            PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
        }

        public static ClassImportance fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        public final Object className;
        public final Importance importance;

        ClassImportance(Object className, Importance importance) {
            this.className = className;
            this.importance = importance;
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            builder.startObject();
            builder.field(CLASS_NAME.getPreferredName(), className);
            builder.field(IMPORTANCE.getPreferredName(), importance);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            ClassImportance that = (ClassImportance) o;
            return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
        }

        @Override
        public int hashCode() {
            return Objects.hash(className, importance);
        }

    }
}
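A hedged sketch of how this new client class consumes a GET trained models payload; the JSON literal and parser setup are illustrative, but createParser and fromXContent are the standard 7.x entry points:

import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;

String json = "{\"feature_name\":\"petal_length\","
    + "\"importance\":{\"mean_magnitude\":1.75,\"min\":-3.1,\"max\":4.2}}";
try (XContentParser parser = XContentType.JSON.xContent().createParser(
        NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
    // The parser is lenient (ConstructingObjectParser created with true),
    // so unknown fields in future responses are ignored.
    TotalFeatureImportance importance = TotalFeatureImportance.fromXContent(parser);
    assert "petal_length".equals(importance.featureName);
}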
@@ -894,7 +894,7 @@ public class MLRequestConvertersTests extends ESTestCase {
         GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
             .setAllowNoMatch(false)
             .setDecompressDefinition(true)
-            .setIncludeDefinition(false)
+            .includeDefinition()
             .setTags("tag1", "tag2")
             .setPageParams(new PageParams(100, 300));
 

@@ -908,7 +908,7 @@ public class MLRequestConvertersTests extends ESTestCase {
             hasEntry("allow_no_match", "false"),
             hasEntry("decompress_definition", "true"),
             hasEntry("tags", "tag1,tag2"),
-            hasEntry("include_model_definition", "false")
+            hasEntry("include", "definition")
         ));
         assertNull(request.getEntity());
     }
@@ -2257,7 +2257,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
         {
             GetTrainedModelsResponse getTrainedModelsResponse = execute(
-                new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true),
+                new GetTrainedModelsRequest(modelIdPrefix + 0)
+                    .setDecompressDefinition(true)
+                    .includeDefinition()
+                    .includeTotalFeatureImportance(),
                 machineLearningClient::getTrainedModels,
                 machineLearningClient::getTrainedModelsAsync);
 

@@ -2268,7 +2271,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
 
         getTrainedModelsResponse = execute(
-            new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true),
+            new GetTrainedModelsRequest(modelIdPrefix + 0)
+                .setDecompressDefinition(false)
+                .includeTotalFeatureImportance()
+                .includeDefinition(),
             machineLearningClient::getTrainedModels,
             machineLearningClient::getTrainedModelsAsync);
 

@@ -2279,7 +2285,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
 
         getTrainedModelsResponse = execute(
-            new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false),
+            new GetTrainedModelsRequest(modelIdPrefix + 0)
+                .setDecompressDefinition(false),
             machineLearningClient::getTrainedModels,
             machineLearningClient::getTrainedModelsAsync);
         assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
@@ -3694,11 +3694,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
         // tag::get-trained-models-request
         GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
             .setPageParams(new PageParams(0, 1)) // <2>
-            .setIncludeDefinition(false) // <3>
-            .setDecompressDefinition(false) // <4>
-            .setAllowNoMatch(true) // <5>
-            .setTags("regression") // <6>
-            .setForExport(false); // <7>
+            .includeDefinition() // <3>
+            .includeTotalFeatureImportance() // <4>
+            .setDecompressDefinition(false) // <5>
+            .setAllowNoMatch(true) // <6>
+            .setTags("regression") // <7>
+            .setForExport(false); // <8>
         // end::get-trained-models-request
         request.setTags((List<String>)null);
 
@@ -0,0 +1,63 @@
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.stream.Collectors;
import java.util.stream.Stream;


public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {


    public static TotalFeatureImportance randomInstance() {
        return new TotalFeatureImportance(
            randomAlphaOfLength(10),
            randomBoolean() ? null : randomImportance(),
            randomBoolean() ?
                null :
                Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
                    .limit(randomIntBetween(1, 10))
                    .collect(Collectors.toList())
        );
    }

    private static TotalFeatureImportance.Importance randomImportance() {
        return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
    }

    @Override
    protected TotalFeatureImportance createTestInstance() {
        return randomInstance();
    }

    @Override
    protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
        return TotalFeatureImportance.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

}
@@ -22,26 +22,28 @@ IDs, or the special wildcard `_all` to get all trained models.
 --------------------------------------------------
 include-tagged::{doc-tests-file}[{api}-request]
 --------------------------------------------------
-<1> Constructing a new GET request referencing an existing Trained Model
+<1> Constructing a new GET request referencing an existing trained model
 <2> Set the paging parameters
 <3> Indicate if the complete model definition should be included
-<4> Should the definition be fully decompressed on GET
-<5> Allow empty response if no Trained Models match the provided ID patterns.
-    If false, an error will be thrown if no Trained Models match the
+<4> Indicate if the total feature importance for the features used in training
+    should be included in the model `metadata` field.
+<5> Should the definition be fully decompressed on GET
+<6> Allow empty response if no trained models match the provided ID patterns.
+    If false, an error will be thrown if no trained models match the
     ID patterns.
-<6> An optional list of tags used to narrow the model search. A Trained Model
+<7> An optional list of tags used to narrow the model search. A trained model
     can have many tags or none. The trained models in the response will
     contain all the provided tags.
-<7> Optional boolean value indicating if certain fields should be removed on
-    retrieval. This is useful for getting the trained model in a format that
-    can then be put into another cluster.
+<8> Optional boolean value for requesting the trained model in a format that can
+    then be put into another cluster. Certain fields that can only be set when
+    the model is imported are removed.
 
 include::../execution.asciidoc[]
 
 [id="{upid}-{api}-response"]
 ==== Response
 
-The returned +{response}+ contains the requested Trained Model.
+The returned +{response}+ contains the requested trained model.
 
 ["source","java",subs="attributes,callouts,macros"]
 --------------------------------------------------
@@ -29,18 +29,19 @@ experimental[]
 [[ml-get-inference-prereq]]
 == {api-prereq-title}
 
-Required privileges which should be added to a custom role:
+If the {es} {security-features} are enabled, you must have the following
+privileges:
 
 * cluster: `monitor_ml`
 
-For more information, see <<security-privileges>> and
+For more information, see <<security-privileges>> and
 {ml-docs-setup-privileges}.
 
 
 [[ml-get-inference-desc]]
 == {api-description-title}
 
-You can get information for multiple trained models in a single API request by
+You can get information for multiple trained models in a single API request by
 using a comma-separated list of model IDs or a wildcard expression.
 

@@ -48,7 +49,7 @@ using a comma-separated list of model IDs or a wildcard expression.
 == {api-path-parms-title}
 
 `<model_id>`::
-(Optional, string)
+(Optional, string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 

@@ -56,12 +57,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 == {api-query-parms-title}
 
 `allow_no_match`::
-(Optional, boolean)
+(Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models]
 
 `decompress_definition`::
 (Optional, boolean)
-Specifies whether the included model definition should be returned as a JSON map
+Specifies whether the included model definition should be returned as a JSON map
 (`true`) or in a custom compressed format (`false`). Defaults to `true`.
 
 `for_export`::

@@ -71,17 +72,21 @@ retrieval. This allows the model to be in an acceptable format to be retrieved
 and then added to another cluster. Default is false.
 
 `from`::
-(Optional, integer)
+(Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models]
 
-`include_model_definition`::
-(Optional, boolean)
-Specifies whether the model definition is returned in the response. Defaults to
-`false`. When `true`, only a single model must match the ID patterns provided.
-Otherwise, a bad request is returned.
+`include`::
+(Optional, string)
+A comma delimited string of optional fields to include in the response body.
+Valid options are:
+ - `definition`: Includes the model definition
+ - `total_feature_importance`: Includes the total feature importance for the
+   training data set. This field is available in the `metadata` field in the
+   response body.
+Default is empty, indicating that no optional fields are included.
 
 `size`::
-(Optional, integer)
+(Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models]
 
 `tags`::

@@ -94,7 +99,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tags]
 
 `trained_model_configs`::
 (array)
-An array of trained model resources, which are sorted by the `model_id` value in
+An array of trained model resources, which are sorted by the `model_id` value in
 ascending order.
 +
 .Properties of trained model resources

@@ -132,8 +137,86 @@ The license level of the trained model.
 
 `metadata`:::
 (object)
-An object containing metadata about the trained model. For example, models
+An object containing metadata about the trained model. For example, models
 created by {dfanalytics} contain `analysis_config` and `input` objects.
+.Properties of metadata
+[%collapsible%open]
+=====
+`total_feature_importance`:::
+(array)
+An array of the total feature importance for each feature used from
+the training data set. This array of objects is returned if {dfanalytics} trained
+the model and the request includes `total_feature_importance` in the `include`
+request parameter.
++
+.Properties of total feature importance
+[%collapsible%open]
+======
+
+`feature_name`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name]
+
+`importance`:::
+(object)
+A collection of feature importance statistics related to the training data set for this particular feature.
++
+.Properties of feature importance
+[%collapsible%open]
+=======
+`mean_magnitude`:::
+(double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
+
+`max`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
+
+`min`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
+
+=======
+
+`classes`:::
+(array)
+If the trained model is a classification model, feature importance statistics are gathered
+per target class value.
++
+.Properties of class feature importance
+[%collapsible%open]
+=======
+
+`class_name`:::
+(string)
+The target class value. Could be a string, boolean, or number.
+
+`importance`:::
+(object)
+A collection of feature importance statistics related to the training data set for this particular feature.
++
+.Properties of feature importance
+[%collapsible%open]
+========
+`mean_magnitude`:::
+(double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
+
+`max`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
+
+`min`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
+
+========
+
+=======
+
+======
+=====
 
 `model_id`:::
 (string)

@@ -152,13 +235,13 @@ The {es} version number in which the trained model was created.
 == {api-response-codes-title}
 
 `400`::
-If `include_model_definition` is `true`, this code indicates that more than
+If `include_model_definition` is `true`, this code indicates that more than
 one model matches the ID pattern.
 
 `404` (Missing resources)::
 If `allow_no_match` is `false`, this code indicates that there are no
 resources that match the request or only partial matches for the request.
 
 
 
 [[ml-get-inference-example]]
 == {api-examples-title}
@@ -780,6 +780,23 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that
 used to train the model, which defaults to `<dependent_variable>_prediction`.
 end::inference-config-results-field-processor[]
 
+tag::inference-metadata-feature-importance-feature-name[]
+The training feature name for which this importance was calculated.
+end::inference-metadata-feature-importance-feature-name[]
+tag::inference-metadata-feature-importance-magnitude[]
+The average magnitude of this feature across all the training data.
+This value is the average of the absolute values of the importance
+for this feature.
+end::inference-metadata-feature-importance-magnitude[]
+tag::inference-metadata-feature-importance-max[]
+The maximum importance value across all the training data for this
+feature.
+end::inference-metadata-feature-importance-max[]
+tag::inference-metadata-feature-importance-min[]
+The minimum importance value across all the training data for this
+feature.
+end::inference-metadata-feature-importance-min[]
+
 tag::influencers[]
 A comma separated list of influencer field names. Typically these can be the by,
 over, or partition fields that are used in the detector configuration. You might
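Read literally, the magnitude description above amounts to the following (an editor's formalization, not text from the commit), where $N$ is the number of rows in the training data and $\phi_i(f)$ is the importance of feature $f$ on row $i$:

\[
\mathrm{mean\_magnitude}(f) = \frac{1}{N} \sum_{i=1}^{N} \bigl|\phi_i(f)\bigr|,
\qquad
\mathrm{min}(f) = \min_i \phi_i(f),
\qquad
\mathrm{max}(f) = \max_i \phi_i(f)
\]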
@@ -10,15 +10,19 @@ import org.elasticsearch.action.ActionType;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.Set;
 
 
 public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {

@@ -32,23 +36,60 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
 
     public static class Request extends AbstractGetResourcesRequest {
 
-        public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
+        static final String DEFINITION = "definition";
+        static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
+        private static final Set<String> KNOWN_INCLUDES;
+        static {
+            HashSet<String> includes = new HashSet<>(2, 1.0f);
+            includes.add(DEFINITION);
+            includes.add(TOTAL_FEATURE_IMPORTANCE);
+            KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
+        }
+        public static final ParseField INCLUDE = new ParseField("include");
+        public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
         public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
         public static final ParseField TAGS = new ParseField("tags");
 
-        private final boolean includeModelDefinition;
+        private final Set<String> includes;
         private final List<String> tags;
 
+        @Deprecated
         public Request(String id, boolean includeModelDefinition, List<String> tags) {
             setResourceId(id);
             setAllowNoResources(true);
-            this.includeModelDefinition = includeModelDefinition;
             this.tags = tags == null ? Collections.emptyList() : tags;
+            if (includeModelDefinition) {
+                this.includes = new HashSet<>(Collections.singletonList(DEFINITION));
+            } else {
+                this.includes = Collections.emptySet();
+            }
+        }
+
+        public Request(String id, List<String> tags, Set<String> includes) {
+            setResourceId(id);
+            setAllowNoResources(true);
+            this.tags = tags == null ? Collections.emptyList() : tags;
+            this.includes = includes == null ? Collections.emptySet() : includes;
+            Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
+            if (unknownIncludes.isEmpty() == false) {
+                throw ExceptionsHelper.badRequestException(
+                    "unknown [include] parameters {}. Valid options are {}",
+                    unknownIncludes,
+                    KNOWN_INCLUDES);
+            }
         }
 
         public Request(StreamInput in) throws IOException {
             super(in);
-            this.includeModelDefinition = in.readBoolean();
+            if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
+                this.includes = in.readSet(StreamInput::readString);
+            } else {
+                Set<String> includes = new HashSet<>();
+                if (in.readBoolean()) {
+                    includes.add(DEFINITION);
+                }
+                this.includes = includes;
+            }
             if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
                 this.tags = in.readStringList();
             } else {

@@ -62,7 +103,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         }
 
         public boolean isIncludeModelDefinition() {
-            return includeModelDefinition;
+            return this.includes.contains(DEFINITION);
+        }
+
+        public boolean isIncludeTotalFeatureImportance() {
+            return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
         }
 
         public List<String> getTags() {

@@ -72,7 +117,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
-            out.writeBoolean(includeModelDefinition);
+            if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
+                out.writeCollection(this.includes, StreamOutput::writeString);
+            } else {
+                out.writeBoolean(this.includes.contains(DEFINITION));
+            }
             if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
                 out.writeStringCollection(tags);
             }

@@ -80,7 +129,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
 
         @Override
         public int hashCode() {
-            return Objects.hash(super.hashCode(), includeModelDefinition, tags);
+            return Objects.hash(super.hashCode(), includes, tags);
         }
 
         @Override

@@ -92,7 +141,18 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
                 return false;
             }
             Request other = (Request) obj;
-            return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
+            return super.equals(obj) && Objects.equals(includes, other.includes) && Objects.equals(tags, other.tags);
+        }
+
+        @Override
+        public String toString() {
+            return "Request{" +
+                "includes=" + includes +
+                ", tags=" + tags +
+                ", page=" + getPageParams() +
+                ", id=" + getResourceId() +
+                ", allow_missing=" + isAllowNoResources() +
+                '}';
         }
     }
 
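The stream constructor and writeTo above encode the wire compatibility story: on pre-7.10 streams the include set collapses to the legacy boolean, so only `definition` survives a mixed-version round trip. An illustrative test-style round trip (the version choice and helper method are assumptions, not from the commit):

import java.io.IOException;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;

static GetTrainedModelsAction.Request roundTripPre710(GetTrainedModelsAction.Request request) throws IOException {
    BytesStreamOutput out = new BytesStreamOutput();
    out.setVersion(Version.V_7_9_0);   // pre-7.10 wire format
    request.writeTo(out);              // serializes includes.contains("definition") as one boolean

    StreamInput in = out.bytes().streamInput();
    in.setVersion(Version.V_7_9_0);
    // total_feature_importance cannot be represented on the old format,
    // so it is silently dropped; definition round-trips intact.
    return new GetTrainedModelsAction.Request(in);
}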
@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;

@@ -39,6 +40,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.action.ValidateActions.addValidationError;

@@ -51,6 +53,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
     public static final String FOR_EXPORT = "for_export";
+    public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
+    private static final Set<String> RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE);
 
     private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
 

@@ -419,7 +423,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
         this.description = config.getDescription();
         this.tags = config.getTags();
-        this.metadata = config.getMetadata();
+        this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata());
         this.input = config.getInput();
         this.estimatedOperations = config.estimatedOperations;
         this.estimatedHeapMemory = config.estimatedHeapMemory;

@@ -471,6 +475,18 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return this;
         }
 
+        public Builder setFeatureImportance(List<TotalFeatureImportance> totalFeatureImportance) {
+            if (totalFeatureImportance == null) {
+                return this;
+            }
+            if (this.metadata == null) {
+                this.metadata = new HashMap<>();
+            }
+            this.metadata.put(TOTAL_FEATURE_IMPORTANCE,
+                totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()));
+            return this;
+        }
+
         public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
             if (definition == null) {
                 return this;

@@ -627,6 +643,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                 ESTIMATED_OPERATIONS.getPreferredName(),
                 validationException);
             validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
+            if (metadata != null) {
+                validationException = checkIllegalSetting(
+                    metadata.get(TOTAL_FEATURE_IMPORTANCE),
+                    METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE,
+                    validationException);
+            }
         }
 
         if (validationException != null) {
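One subtle fix above: the builder previously aliased the source config's metadata map, so setFeatureImportance would have mutated the original TrainedModelConfig. A sketch of the hazard the defensive HashMap copy removes (illustrative; assumes the copying Builder constructor shown in the hunk above):

TrainedModelConfig.Builder builder = new TrainedModelConfig.Builder(originalConfig);

// Pre-fix: builder.metadata was the same map instance as
// originalConfig.getMetadata(), so this put() leaked the importance list
// back into the original config. Post-fix: the builder mutates its own copy.
builder.setFeatureImportance(importances);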
@@ -20,8 +20,11 @@ import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
 
 public class TotalFeatureImportance implements ToXContentObject, Writeable {
 

@@ -81,16 +84,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        builder.field(FEATURE_NAME.getPreferredName(), featureName);
-        if (importance != null) {
-            builder.field(IMPORTANCE.getPreferredName(), importance);
-        }
-        if (classImportances.isEmpty() == false) {
-            builder.field(CLASSES.getPreferredName(), classImportances);
-        }
-        builder.endObject();
-        return builder;
+        return builder.map(asMap());
     }
 
     @Override

@@ -103,6 +97,18 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             && Objects.equals(classImportances, that.classImportances);
     }
 
+    public Map<String, Object> asMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(FEATURE_NAME.getPreferredName(), featureName);
+        if (importance != null) {
+            map.put(IMPORTANCE.getPreferredName(), importance.asMap());
+        }
+        if (classImportances.isEmpty() == false) {
+            map.put(CLASSES.getPreferredName(), classImportances.stream().map(ClassImportance::asMap).collect(Collectors.toList()));
+        }
+        return map;
+    }
+
     @Override
     public int hashCode() {
         return Objects.hash(featureName, importance, classImportances);

@@ -165,12 +171,15 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
-            builder.field(MIN.getPreferredName(), min);
-            builder.field(MAX.getPreferredName(), max);
-            builder.endObject();
-            return builder;
+            return builder.map(asMap());
+        }
+
+        private Map<String, Object> asMap() {
+            Map<String, Object> map = new LinkedHashMap<>();
+            map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
+            map.put(MIN.getPreferredName(), min);
+            map.put(MAX.getPreferredName(), max);
+            return map;
         }
     }
 

@@ -229,11 +238,14 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(CLASS_NAME.getPreferredName(), className);
-            builder.field(IMPORTANCE.getPreferredName(), importance);
-            builder.endObject();
-            return builder;
+            return builder.map(asMap());
+        }
+
+        private Map<String, Object> asMap() {
+            Map<String, Object> map = new LinkedHashMap<>();
+            map.put(CLASS_NAME.getPreferredName(), className);
+            map.put(IMPORTANCE.getPreferredName(), importance.asMap());
+            return map;
         }
 
         @Override
@@ -53,6 +53,10 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
         return NAME + "-" + modelId;
     }
 
+    public static String modelId(String docId) {
+        return docId.substring(NAME.length() + 1);
+    }
+
     private final List<TotalFeatureImportance> totalFeatureImportances;
     private final String modelId;
 
@@ -103,7 +103,7 @@ public final class Messages {
     public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
         "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
     public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
-    public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]";
+    public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}";
     public static final String INFERENCE_CANNOT_DELETE_MODEL =
         "Unable to delete model [{0}]";
     public static final String MODEL_DEFINITION_TRUNCATED =
@@ -5,19 +5,28 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.action.util.PageParams;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
 
-public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
 
     @Override
     protected Request createTestInstance() {
         Request request = new Request(randomAlphaOfLength(20),
-            randomBoolean(),
             randomBoolean() ? null :
-                randomList(10, () -> randomAlphaOfLength(10)));
+                randomList(10, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null :
+                Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE))
+                    .limit(4)
+                    .collect(Collectors.toSet()));
         request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
         return request;
     }

@@ -26,4 +35,22 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
     protected Writeable.Reader<Request> instanceReader() {
         return Request::new;
     }
+
+    @Override
+    protected Request mutateInstanceForVersion(Request instance, Version version) {
+        if (version.before(Version.V_7_10_0)) {
+            Set<String> includes = new HashSet<>();
+            if (instance.isIncludeModelDefinition()) {
+                includes.add(Request.DEFINITION);
+            }
+            Request request = new Request(
+                instance.getResourceId(),
+                version.before(Version.V_7_7_0) ? null : instance.getTags(),
+                includes);
+            request.setPageParams(instance.getPageParams());
+            request.setAllowNoResources(instance.isAllowNoResources());
+            return request;
+        }
+        return instance;
+    }
 }
@@ -42,11 +42,13 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.startsWith;
 
 public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {

@@ -95,19 +97,21 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
         trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
         Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
         assertThat(ids.v1(), equalTo(1L));
+        String inferenceModelId = ids.v2().iterator().next();
 
         PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
-        trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);
+        trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture);
 
         TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
         assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
         assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
         assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
+        assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
 
-        PlainActionFuture<TrainedModelMetadata> getTrainedMetadataFuture = new PlainActionFuture<>();
-        trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture);
+        PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
+        trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
 
-        TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet();
+        TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
         assertThat(storedMetadata.getModelId(), startsWith(modelId));
         assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances()));
     }
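The provider now resolves metadata for a whole batch of model IDs with one search and hands back a map keyed by model ID, which is what the updated test above consumes. A hedged sketch of a caller (logger and surrounding setup are assumptions, not from the commit):

List<String> modelIds = Arrays.asList("model-a", "model-b");
trainedModelProvider.getTrainedModelMetadata(modelIds, ActionListener.wrap(
    metadataByModelId -> {
        // If none of the IDs has stored metadata the listener fails with a
        // resource-not-found; otherwise models without metadata are simply
        // absent from the map.
        TrainedModelMetadata metadata = metadataByModelId.get("model-a");
        if (metadata != null) {
            logger.info("[{}] has {} total feature importance entries",
                metadata.getModelId(), metadata.getTotalFeatureImportances().size());
        }
    },
    e -> logger.warn("trained model metadata lookup failed", e)
));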
@@ -90,7 +90,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         assertThat(exceptionHolder.get(), is(nullValue()));
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
         assertThat(getConfigHolder.get(), is(not(nullValue())));
         assertThat(getConfigHolder.get(), equalTo(config));

@@ -121,7 +124,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         assertThat(exceptionHolder.get(), is(nullValue()));
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder);
+        blockingCall(listener ->
+            trainedModelProvider.getTrainedModel(modelId, false, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
         assertThat(getConfigHolder.get(), is(not(nullValue())));
         assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition));

@@ -132,7 +138,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         String modelId = "test-get-missing-trained-model-config";
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(),
             equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));

@@ -154,7 +163,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
             .actionGet();
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(),
             equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));

@@ -193,7 +205,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         }
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(getConfigHolder.get(), is(nullValue()));
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));

@@ -238,7 +253,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
             }
         }
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(getConfigHolder.get(), is(nullValue()));
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
@@ -57,15 +57,25 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
                 }
 
                 if (request.isIncludeModelDefinition()) {
-                    provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap(
-                        config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
-                        listener::onFailure
-                    ));
+                    provider.getTrainedModel(
+                        totalAndIds.v2().iterator().next(),
+                        true,
+                        request.isIncludeTotalFeatureImportance(),
+                        ActionListener.wrap(
+                            config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
+                            listener::onFailure
+                        )
+                    );
                 } else {
-                    provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap(
-                        configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
-                        listener::onFailure
-                    ));
+                    provider.getTrainedModels(
+                        totalAndIds.v2(),
+                        request.isAllowNoResources(),
+                        request.isIncludeTotalFeatureImportance(),
+                        ActionListener.wrap(
+                            configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
+                            listener::onFailure
+                        )
+                    );
                 }
             },
             listener::onFailure
@@ -82,7 +82,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
             responseBuilder.setLicensed(true);
             this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
         } else {
-            trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
+            trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap(
                 trainedModelConfig -> {
                     responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
                     if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
@ -270,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
}
|
||||
|
||||
private void loadModel(String modelId, Consumer consumer) {
|
||||
provider.getTrainedModel(modelId, false, ActionListener.wrap(
|
||||
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
|
||||
trainedModelConfig -> {
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
|
||||
|
@@ -306,7 +306,7 @@ public class ModelLoadingService implements ClusterStateListener {
// If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, false, ActionListener.wrap(
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);

@@ -434,7 +434,7 @@ public class ModelLoadingService implements ClusterStateListener {

logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
notification.getValue().model.getModelId()));

// If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
} finally {

@@ -89,9 +89,11 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

@@ -235,14 +237,14 @@ public class TrainedModelProvider {
));
}

public void getTrainedModelMetadata(String modelId, ActionListener<TrainedModelMetadata> listener) {
public void getTrainedModelMetadata(Collection<String> modelIds, ActionListener<Map<String, TrainedModelMetadata>> listener) {
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
.boolQuery()
.filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
.filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
TrainedModelMetadata.NAME))))
.setSize(1)
.setSize(10_000)
// First find the latest index
.addSort("_index", SortOrder.DESC)
.request();

@@ -250,18 +252,20 @@ public class TrainedModelProvider {
searchResponse -> {
if (searchResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
return;
}
List<TrainedModelMetadata> metadataList = handleHits(searchResponse.getHits().getHits(),
modelId,
this::parseMetadataLenientlyFromSource);
listener.onResponse(metadataList.get(0));
HashMap<String, TrainedModelMetadata> map = new HashMap<>();
for (SearchHit hit : searchResponse.getHits().getHits()) {
String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId()));
map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId));
}
listener.onResponse(map);
},
e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
return;
}
listener.onFailure(e);

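The metadata lookup now returns one document per model, keyed by model id. The collapse-to-latest trick relies on the descending sort on _index combined with putIfAbsent; a self-contained sketch of that idea, with a hypothetical Hit stand-in for a parsed search hit:

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Hit is a hypothetical stand-in for a search hit carrying a parsed metadata payload.
    final class LatestMetadataSketch {

        record Hit(String modelId, String payload) {}

        // Hits arrive sorted by index name descending, so the first hit seen for a model id
        // comes from the newest index; putIfAbsent keeps it and drops older duplicates.
        static Map<String, String> latestPerModel(List<Hit> hits) {
            Map<String, String> latest = new HashMap<>();
            for (Hit hit : hits) {
                latest.putIfAbsent(hit.modelId(), hit.payload());
            }
            return latest;
        }
    }
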
@@ -371,7 +375,7 @@ public class TrainedModelProvider {
// TODO Change this when we get more than just langIdent stored
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
listener.onResponse(
InferenceDefinition.builder()

@@ -434,18 +438,50 @@ public class TrainedModelProvider {
));
}

public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
public void getTrainedModel(final String modelId,
final boolean includeDefinition,
final boolean includeTotalFeatureImportance,
final ActionListener<TrainedModelConfig> finalListener) {

if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
return;
} catch (ElasticsearchException ex) {
listener.onFailure(ex);
finalListener.onFailure(ex);
return;
}
}

ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
modelBuilder -> {
if (includeTotalFeatureImportance == false) {
finalListener.onResponse(modelBuilder.build());
return;
}
this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap(
metadata -> {
TrainedModelMetadata modelMetadata = metadata.get(modelId);
if (modelMetadata != null) {
modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
finalListener.onResponse(modelBuilder.build());
},
failure -> {
// total feature importance is not necessary for a model to be valid
// we shouldn't fail if it is not found
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
finalListener.onResponse(modelBuilder.build());
return;
}
finalListener.onFailure(failure);
}
));

},
finalListener::onFailure
);

QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
.idsQuery()
.addIds(modelId));

@@ -483,11 +519,11 @@ public class TrainedModelProvider {
try {
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
} catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException(
getTrainedModelListener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return;
} catch (Exception ex) {
listener.onFailure(ex);
getTrainedModelListener.onFailure(ex);
return;
}

@@ -500,22 +536,22 @@ public class TrainedModelProvider {
String compressedString = getDefinitionFromDocs(docs, modelId);
builder.setDefinitionFromString(compressedString);
} catch (ElasticsearchException elasticsearchException) {
listener.onFailure(elasticsearchException);
getTrainedModelListener.onFailure(elasticsearchException);
return;
}

} catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException(
getTrainedModelListener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
} catch (Exception ex) {
listener.onFailure(ex);
getTrainedModelListener.onFailure(ex);
return;
}
}
listener.onResponse(builder.build());
getTrainedModelListener.onResponse(builder);
},
listener::onFailure
getTrainedModelListener::onFailure
);

executeAsyncWithOrigin(client,

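The rewritten getTrainedModel threads everything through the intermediate getTrainedModelListener: the config search produces a Builder, and only when total feature importance was requested is a second metadata lookup chained on, with a not-found result treated as benign. A runnable, simplified sketch of that chaining pattern; every name below is a stand-in, not an Elasticsearch class:

    import java.util.NoSuchElementException;
    import java.util.function.Consumer;

    final class EnrichmentChainSketch {

        interface Listener<T> {
            void onResponse(T result);
            void onFailure(Exception e);
        }

        static <T> Listener<T> wrap(Consumer<T> response, Consumer<Exception> failure) {
            return new Listener<T>() {
                @Override public void onResponse(T result) { response.accept(result); }
                @Override public void onFailure(Exception e) { failure.accept(e); }
            };
        }

        // Stand-in for the config search: always succeeds.
        static void fetchConfigBuilder(Listener<StringBuilder> listener) {
            listener.onResponse(new StringBuilder("config"));
        }

        // Stand-in for the metadata search: simulates a missing metadata document.
        static void fetchImportanceMetadata(Listener<String> listener) {
            listener.onFailure(new NoSuchElementException("no metadata doc"));
        }

        static void getModel(boolean includeImportance, Listener<String> finalListener) {
            Listener<StringBuilder> chained = wrap(builder -> {
                if (includeImportance == false) {
                    finalListener.onResponse(builder.toString());
                    return;
                }
                fetchImportanceMetadata(wrap(
                    meta -> finalListener.onResponse(builder.append('+').append(meta).toString()),
                    failure -> {
                        // Importance metadata is optional: a missing document falls back to the
                        // bare config, mirroring the ResourceNotFoundException handling above.
                        if (failure instanceof NoSuchElementException) {
                            finalListener.onResponse(builder.toString());
                        } else {
                            finalListener.onFailure(failure);
                        }
                    }));
            }, finalListener::onFailure);
            fetchConfigBuilder(chained);
        }
    }
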
@@ -532,7 +568,10 @@ public class TrainedModelProvider {
* This does no expansion on the ids.
* It assumes that there are fewer than 10k.
*/
public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) {
public void getTrainedModels(Set<String> modelIds,
boolean allowNoResources,
boolean includeTotalFeatureImportance,
final ActionListener<List<TrainedModelConfig>> finalListener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));

SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)

@@ -541,23 +580,63 @@ public class TrainedModelProvider {
.setQuery(queryBuilder)
.setSize(modelIds.size())
.request();
List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
List<TrainedModelConfig.Builder> configs = new ArrayList<>(modelIds.size());
Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
for(String modelId : modelsAsResource) {
try {
configs.add(loadModelFromResource(modelId, true));
} catch (ElasticsearchException ex) {
listener.onFailure(ex);
finalListener.onFailure(ex);
return;
}
}
if (modelsInIndex.isEmpty()) {
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
listener.onResponse(configs);
finalListener.onResponse(configs.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}

ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
modelBuilders -> {
if (includeTotalFeatureImportance == false) {
finalListener.onResponse(modelBuilders.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}
this.getTrainedModelMetadata(modelIds, ActionListener.wrap(
metadata ->
finalListener.onResponse(modelBuilders.stream()
.map(builder -> {
TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
if (modelMetadata != null) {
builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
}
return builder.build();
})
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList())),
failure -> {
// total feature importance is not necessary for a model to be valid
// we shouldn't fail if it is not found
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
finalListener.onResponse(modelBuilders.stream()
.map(TrainedModelConfig.Builder::build)
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
.collect(Collectors.toList()));
return;
}
finalListener.onFailure(failure);
}
));
},
finalListener::onFailure
);

ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
searchResponse -> {
Set<String> observedIds = new HashSet<>(

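The multi-model path first splits the requested ids into index-backed and resource-backed sets via Sets.difference and Sets.intersection. The equivalent with plain JDK collections, for illustration:

    import java.util.HashSet;
    import java.util.Set;

    // Same partition as Sets.difference / Sets.intersection, with plain JDK sets.
    final class PartitionSketch {

        // Ids that must be fetched from the inference index.
        static Set<String> inIndex(Set<String> requested, Set<String> bundledResources) {
            Set<String> result = new HashSet<>(requested);
            result.removeAll(bundledResources);
            return result;
        }

        // Ids served from classpath resources (e.g. lang_ident_model_1).
        static Set<String> asResource(Set<String> requested, Set<String> bundledResources) {
            Set<String> result = new HashSet<>(requested);
            result.retainAll(bundledResources);
            return result;
        }
    }
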
@@ -568,12 +647,12 @@ public class TrainedModelProvider {
try {
if (observedIds.contains(searchHit.getId()) == false) {
configs.add(
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build()
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId())
);
observedIds.add(searchHit.getId());
}
} catch (IOException ex) {
listener.onFailure(
getTrainedModelListener.onFailure(
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
return;
}

@@ -583,14 +662,13 @@ public class TrainedModelProvider {
// Otherwise, treat it as if it was never expanded to begin with.
Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
if (missingConfigs.isEmpty() == false && allowNoResources == false) {
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
return;
}
// Ensure sorted even with the injection of locally resourced models
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
listener.onResponse(configs);
getTrainedModelListener.onResponse(configs);
},
listener::onFailure
getTrainedModelListener::onFailure
);

executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);

@@ -639,7 +717,7 @@ public class TrainedModelProvider {
foundResourceIds = new HashSet<>();
for(String resourceId : matchedResourceIds) {
// Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) {
foundResourceIds.add(resourceId);
}
}

@@ -833,7 +911,7 @@ public class TrainedModelProvider {
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
}

TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) {
URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
if (resource == null) {
logger.error("[{}] presumed stored as a resource but not found", modelId);

@@ -848,7 +926,7 @@ public class TrainedModelProvider {
if (nullOutDefinition) {
builder.clearDefinition();
}
return builder.build();
return builder;
} catch (IOException ioEx) {
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);

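loadModelFromResource now hands back the Builder instead of a finished config, so callers can attach optional metadata before freezing the object. A small sketch of why that handoff matters, with illustrative Config/Builder types rather than the real ones:

    // Illustrative Config/Builder pair; not the Elasticsearch classes.
    final class BuilderHandoffSketch {

        static final class Config {
            final String modelId;
            final String importance; // may be null when not requested or not found
            Config(String modelId, String importance) {
                this.modelId = modelId;
                this.importance = importance;
            }
        }

        static final class Builder {
            private String modelId;
            private String importance;
            Builder modelId(String modelId) { this.modelId = modelId; return this; }
            Builder importance(String importance) { this.importance = importance; return this; }
            Config build() { return new Config(modelId, importance); }
        }

        static Config load(String modelId, String maybeImportance) {
            Builder builder = new Builder().modelId(modelId); // what loadModelFromResource now returns
            if (maybeImportance != null) {
                builder.importance(maybeImportance);          // enrichment enabled by returning the builder
            }
            return builder.build();                           // frozen only at the API boundary
        }
    }
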
@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

@@ -56,12 +57,17 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
if (Strings.isNullOrEmpty(modelId)) {
modelId = Metadata.ALL;
}
boolean includeModelDefinition = restRequest.paramAsBoolean(
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
false
);
List<String> tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
Set<String> includes = new HashSet<>(
asList(
restRequest.paramAsStringArray(
GetTrainedModelsAction.Request.INCLUDE.getPreferredName(),
Strings.EMPTY_ARRAY)));
final GetTrainedModelsAction.Request request = restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION) ?
new GetTrainedModelsAction.Request(modelId,
restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false),
tags) :
new GetTrainedModelsAction.Request(modelId, tags, includes);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));

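The REST handler keeps the deprecated include_model_definition parameter working: if the caller actually sent it, the old boolean constructor is used; otherwise the new include set wins. A stand-alone sketch of that back-compat dispatch, where Request is a simplified stand-in for GetTrainedModelsAction.Request:

    import java.util.Set;

    final class BwcIncludeSketch {

        record Request(String modelId, Set<String> includes) {}

        static Request fromParams(String modelId, Boolean deprecatedIncludeDefinition, Set<String> includes) {
            if (deprecatedIncludeDefinition != null) {
                // the old boolean flag maps onto the new include set
                return new Request(modelId, deprecatedIncludeDefinition ? Set.of("definition") : Set.of());
            }
            return new Request(modelId, includes);
        }
    }
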
@@ -437,9 +437,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
// the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three
assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any());
});
assertBusy(() -> {
assertThat(circuitBreaker.getUsed(), equalTo(10L));

@@ -553,10 +553,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
}

@SuppressWarnings("unchecked")

@@ -564,20 +564,20 @@ public class ModelLoadingServiceTests extends ESTestCase {
if (randomBoolean()) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
} else {
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];

@@ -57,14 +57,14 @@ public class TrainedModelProviderTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, true, future);
trainedModelProvider.getTrainedModel(modelId, true, false, future);
TrainedModelConfig configWithDefinition = future.actionGet();

assertThat(configWithDefinition.getModelId(), equalTo(modelId));
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));

PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition);
trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition);
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();

assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));

@@ -33,7 +33,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
// Should be OK as we don't make any client calls
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future);
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future);
TrainedModelConfig config = future.actionGet();

config.ensureParsedDefinition(xContentRegistry());

@@ -34,11 +34,10 @@
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
"default":true
},
"include_model_definition":{
"type":"boolean",
"include":{
"type":"string",
"required":false,
"description":"Should the full model definition be included in the results. These definitions can be large. So be cautious when including them. Defaults to false.",
"default":false
"description":"A comma-separated list of fields to optionally include. Valid options are 'definition' and 'total_feature_importance'. Default is none."
},
"decompress_definition":{
"type":"boolean",

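On the client side, the new includes surface through the high-level REST client's fluent setters. A minimal usage sketch, assuming the standard varargs GetTrainedModelsRequest constructor; the model id is illustrative and executing the request via the client is omitted:

    import org.elasticsearch.client.ml.GetTrainedModelsRequest;

    final class IncludeUsageSketch {
        // "my-regression-model" is an illustrative model id.
        static GetTrainedModelsRequest buildRequest() {
            return new GetTrainedModelsRequest("my-regression-model")
                .includeDefinition()
                .includeTotalFeatureImportance();
        }
    }
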
@@ -1,6 +1,24 @@
setup:
  - skip:
      features: headers
      features:
        - headers
        - allowed_warnings
  - do:
      allowed_warnings:
        - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template"
      headers:
        Content-Type: application/json
        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
      index:
        id: trained_model_metadata-a-regression-model-0
        index: .ml-inference-000003
        body:
          model_id: "a-regression-model-0"
          doc_type: "trained_model_metadata"
          total_feature_importance:
            - { feature_name: "foo", importance: { mean_magnitude: 6.0, min: -3.0, max: 3.0 }}
            - { feature_name: "bar", importance: { mean_magnitude: 6.0, min: -3.0, max: 3.0 }}

  - do:
      headers:
        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser

@@ -548,6 +566,20 @@ setup:
- match: { count: 12 }
- match: { trained_model_configs.0.model_id: "a-regression-model-1" }
---
"Test get models with include total feature importance":
  - do:
      ml.get_trained_models:
        model_id: "a-regression-model-*"
        include: "total_feature_importance"
  - match: { count: 2 }
  - length: { trained_model_configs: 2 }
  - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
  - is_true: trained_model_configs.0.metadata.total_feature_importance
  - length: { trained_model_configs.0.metadata.total_feature_importance: 2 }
  - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
  - is_false: trained_model_configs.1.metadata.total_feature_importance

---
"Test delete given unused trained model":
  - do:
      ml.delete_trained_model:

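Only a-regression-model-0 is given a trained_model_metadata document in the setup, so the assertions expect total_feature_importance on the first config and its absence on the second.
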
@@ -824,7 +856,7 @@ setup:
ml.get_trained_models:
  model_id: "a-regression-model-1"
  for_export: true
  include_model_definition: true
  include: "definition"
  decompress_definition: false

- match: { trained_model_configs.0.description: "empty model for tests" }