[7.x] [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922) (#62620)
* [ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922) Adds new flag include to the get trained models API The flag initially has two valid values: definition, total_feature_importance. Consequently, the old include_model_definition flag is now deprecated. When total_feature_importance is included, the total_feature_importance field is included in the model metadata object. Including definition is the same as previously setting include_model_definition=true. * fixing test * Update x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
This commit is contained in:
parent
e1a4a3073a
commit
e163559e4c
|
@ -779,9 +779,9 @@ final class MLRequestConverters {
|
||||||
params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
|
params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
|
||||||
Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
|
Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
|
||||||
}
|
}
|
||||||
if (getTrainedModelsRequest.getIncludeDefinition() != null) {
|
if (getTrainedModelsRequest.getIncludes().isEmpty() == false) {
|
||||||
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
|
params.putParam(GetTrainedModelsRequest.INCLUDE,
|
||||||
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
|
Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes()));
|
||||||
}
|
}
|
||||||
if (getTrainedModelsRequest.getTags() != null) {
|
if (getTrainedModelsRequest.getTags() != null) {
|
||||||
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
|
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
|
||||||
|
|
|
@ -26,21 +26,26 @@ import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
public class GetTrainedModelsRequest implements Validatable {
|
public class GetTrainedModelsRequest implements Validatable {
|
||||||
|
|
||||||
|
private static final String DEFINITION = "definition";
|
||||||
|
private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
|
||||||
public static final String ALLOW_NO_MATCH = "allow_no_match";
|
public static final String ALLOW_NO_MATCH = "allow_no_match";
|
||||||
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
|
|
||||||
public static final String FOR_EXPORT = "for_export";
|
public static final String FOR_EXPORT = "for_export";
|
||||||
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
||||||
public static final String TAGS = "tags";
|
public static final String TAGS = "tags";
|
||||||
|
public static final String INCLUDE = "include";
|
||||||
|
|
||||||
private final List<String> ids;
|
private final List<String> ids;
|
||||||
private Boolean allowNoMatch;
|
private Boolean allowNoMatch;
|
||||||
private Boolean includeDefinition;
|
private Set<String> includes = new HashSet<>();
|
||||||
private Boolean decompressDefinition;
|
private Boolean decompressDefinition;
|
||||||
private Boolean forExport;
|
private Boolean forExport;
|
||||||
private PageParams pageParams;
|
private PageParams pageParams;
|
||||||
|
@ -86,19 +91,32 @@ public class GetTrainedModelsRequest implements Validatable {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Boolean getIncludeDefinition() {
|
public Set<String> getIncludes() {
|
||||||
return includeDefinition;
|
return Collections.unmodifiableSet(includes);
|
||||||
|
}
|
||||||
|
|
||||||
|
public GetTrainedModelsRequest includeDefinition() {
|
||||||
|
this.includes.add(DEFINITION);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public GetTrainedModelsRequest includeTotalFeatureImportance() {
|
||||||
|
this.includes.add(TOTAL_FEATURE_IMPORTANCE);
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Whether to include the full model definition.
|
* Whether to include the full model definition.
|
||||||
*
|
*
|
||||||
* The full model definition can be very large.
|
* The full model definition can be very large.
|
||||||
*
|
* @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()}
|
||||||
* @param includeDefinition If {@code true}, the definition is included.
|
* @param includeDefinition If {@code true}, the definition is included.
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
|
public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
|
||||||
this.includeDefinition = includeDefinition;
|
if (includeDefinition != null && includeDefinition) {
|
||||||
|
return this.includeDefinition();
|
||||||
|
}
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -173,13 +191,13 @@ public class GetTrainedModelsRequest implements Validatable {
|
||||||
return Objects.equals(ids, other.ids)
|
return Objects.equals(ids, other.ids)
|
||||||
&& Objects.equals(allowNoMatch, other.allowNoMatch)
|
&& Objects.equals(allowNoMatch, other.allowNoMatch)
|
||||||
&& Objects.equals(decompressDefinition, other.decompressDefinition)
|
&& Objects.equals(decompressDefinition, other.decompressDefinition)
|
||||||
&& Objects.equals(includeDefinition, other.includeDefinition)
|
&& Objects.equals(includes, other.includes)
|
||||||
&& Objects.equals(forExport, other.forExport)
|
&& Objects.equals(forExport, other.forExport)
|
||||||
&& Objects.equals(pageParams, other.pageParams);
|
&& Objects.equals(pageParams, other.pageParams);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
|
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,208 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Nullable;
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParseException;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class TotalFeatureImportance implements ToXContentObject {
|
||||||
|
|
||||||
|
private static final String NAME = "total_feature_importance";
|
||||||
|
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
|
||||||
|
public static final ParseField IMPORTANCE = new ParseField("importance");
|
||||||
|
public static final ParseField CLASSES = new ParseField("classes");
|
||||||
|
public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
|
||||||
|
public static final ParseField MIN = new ParseField("min");
|
||||||
|
public static final ParseField MAX = new ParseField("max");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||||
|
true,
|
||||||
|
a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
|
||||||
|
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
|
||||||
|
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TotalFeatureImportance fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public final String featureName;
|
||||||
|
public final Importance importance;
|
||||||
|
public final List<ClassImportance> classImportances;
|
||||||
|
|
||||||
|
TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
|
||||||
|
this.featureName = featureName;
|
||||||
|
this.importance = importance;
|
||||||
|
this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
||||||
|
if (importance != null) {
|
||||||
|
builder.field(IMPORTANCE.getPreferredName(), importance);
|
||||||
|
}
|
||||||
|
if (classImportances.isEmpty() == false) {
|
||||||
|
builder.field(CLASSES.getPreferredName(), classImportances);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
TotalFeatureImportance that = (TotalFeatureImportance) o;
|
||||||
|
return Objects.equals(that.importance, importance)
|
||||||
|
&& Objects.equals(featureName, that.featureName)
|
||||||
|
&& Objects.equals(classImportances, that.classImportances);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(featureName, importance, classImportances);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Importance implements ToXContentObject {
|
||||||
|
private static final String NAME = "importance";
|
||||||
|
|
||||||
|
public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||||
|
true,
|
||||||
|
a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
|
||||||
|
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final double meanMagnitude;
|
||||||
|
private final double min;
|
||||||
|
private final double max;
|
||||||
|
|
||||||
|
public Importance(double meanMagnitude, double min, double max) {
|
||||||
|
this.meanMagnitude = meanMagnitude;
|
||||||
|
this.min = min;
|
||||||
|
this.max = max;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
Importance that = (Importance) o;
|
||||||
|
return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
|
||||||
|
Double.compare(that.min, min) == 0 &&
|
||||||
|
Double.compare(that.max, max) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(meanMagnitude, min, max);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
|
||||||
|
builder.field(MIN.getPreferredName(), min);
|
||||||
|
builder.field(MAX.getPreferredName(), max);
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class ClassImportance implements ToXContentObject {
|
||||||
|
private static final String NAME = "total_class_importance";
|
||||||
|
|
||||||
|
public static final ParseField CLASS_NAME = new ParseField("class_name");
|
||||||
|
public static final ParseField IMPORTANCE = new ParseField("importance");
|
||||||
|
|
||||||
|
public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
|
||||||
|
true,
|
||||||
|
a -> new ClassImportance(a[0], (Importance)a[1]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
|
||||||
|
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
|
||||||
|
return p.text();
|
||||||
|
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
|
||||||
|
return p.numberValue();
|
||||||
|
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
|
||||||
|
return p.booleanValue();
|
||||||
|
}
|
||||||
|
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
|
||||||
|
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
|
||||||
|
PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ClassImportance fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public final Object className;
|
||||||
|
public final Importance importance;
|
||||||
|
|
||||||
|
ClassImportance(Object className, Importance importance) {
|
||||||
|
this.className = className;
|
||||||
|
this.importance = importance;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(CLASS_NAME.getPreferredName(), className);
|
||||||
|
builder.field(IMPORTANCE.getPreferredName(), importance);
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
ClassImportance that = (ClassImportance) o;
|
||||||
|
return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(className, importance);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -894,7 +894,7 @@ public class MLRequestConvertersTests extends ESTestCase {
|
||||||
GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
|
GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
|
||||||
.setAllowNoMatch(false)
|
.setAllowNoMatch(false)
|
||||||
.setDecompressDefinition(true)
|
.setDecompressDefinition(true)
|
||||||
.setIncludeDefinition(false)
|
.includeDefinition()
|
||||||
.setTags("tag1", "tag2")
|
.setTags("tag1", "tag2")
|
||||||
.setPageParams(new PageParams(100, 300));
|
.setPageParams(new PageParams(100, 300));
|
||||||
|
|
||||||
|
@ -908,7 +908,7 @@ public class MLRequestConvertersTests extends ESTestCase {
|
||||||
hasEntry("allow_no_match", "false"),
|
hasEntry("allow_no_match", "false"),
|
||||||
hasEntry("decompress_definition", "true"),
|
hasEntry("decompress_definition", "true"),
|
||||||
hasEntry("tags", "tag1,tag2"),
|
hasEntry("tags", "tag1,tag2"),
|
||||||
hasEntry("include_model_definition", "false")
|
hasEntry("include", "definition")
|
||||||
));
|
));
|
||||||
assertNull(request.getEntity());
|
assertNull(request.getEntity());
|
||||||
}
|
}
|
||||||
|
|
|
@ -2257,7 +2257,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
|
|
||||||
{
|
{
|
||||||
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
||||||
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true),
|
new GetTrainedModelsRequest(modelIdPrefix + 0)
|
||||||
|
.setDecompressDefinition(true)
|
||||||
|
.includeDefinition()
|
||||||
|
.includeTotalFeatureImportance(),
|
||||||
machineLearningClient::getTrainedModels,
|
machineLearningClient::getTrainedModels,
|
||||||
machineLearningClient::getTrainedModelsAsync);
|
machineLearningClient::getTrainedModelsAsync);
|
||||||
|
|
||||||
|
@ -2268,7 +2271,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
|
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
|
||||||
|
|
||||||
getTrainedModelsResponse = execute(
|
getTrainedModelsResponse = execute(
|
||||||
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true),
|
new GetTrainedModelsRequest(modelIdPrefix + 0)
|
||||||
|
.setDecompressDefinition(false)
|
||||||
|
.includeTotalFeatureImportance()
|
||||||
|
.includeDefinition(),
|
||||||
machineLearningClient::getTrainedModels,
|
machineLearningClient::getTrainedModels,
|
||||||
machineLearningClient::getTrainedModelsAsync);
|
machineLearningClient::getTrainedModelsAsync);
|
||||||
|
|
||||||
|
@ -2279,7 +2285,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
|
assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
|
||||||
|
|
||||||
getTrainedModelsResponse = execute(
|
getTrainedModelsResponse = execute(
|
||||||
new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false),
|
new GetTrainedModelsRequest(modelIdPrefix + 0)
|
||||||
|
.setDecompressDefinition(false),
|
||||||
machineLearningClient::getTrainedModels,
|
machineLearningClient::getTrainedModels,
|
||||||
machineLearningClient::getTrainedModelsAsync);
|
machineLearningClient::getTrainedModelsAsync);
|
||||||
assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
|
assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
|
||||||
|
|
|
@ -3694,11 +3694,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
// tag::get-trained-models-request
|
// tag::get-trained-models-request
|
||||||
GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
|
GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
|
||||||
.setPageParams(new PageParams(0, 1)) // <2>
|
.setPageParams(new PageParams(0, 1)) // <2>
|
||||||
.setIncludeDefinition(false) // <3>
|
.includeDefinition() // <3>
|
||||||
.setDecompressDefinition(false) // <4>
|
.includeTotalFeatureImportance() // <4>
|
||||||
.setAllowNoMatch(true) // <5>
|
.setDecompressDefinition(false) // <5>
|
||||||
.setTags("regression") // <6>
|
.setAllowNoMatch(true) // <6>
|
||||||
.setForExport(false); // <7>
|
.setTags("regression") // <7>
|
||||||
|
.setForExport(false); // <8>
|
||||||
// end::get-trained-models-request
|
// end::get-trained-models-request
|
||||||
request.setTags((List<String>)null);
|
request.setTags((List<String>)null);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
|
public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
|
||||||
|
|
||||||
|
|
||||||
|
public static TotalFeatureImportance randomInstance() {
|
||||||
|
return new TotalFeatureImportance(
|
||||||
|
randomAlphaOfLength(10),
|
||||||
|
randomBoolean() ? null : randomImportance(),
|
||||||
|
randomBoolean() ?
|
||||||
|
null :
|
||||||
|
Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
|
||||||
|
.limit(randomIntBetween(1, 10))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TotalFeatureImportance.Importance randomImportance() {
|
||||||
|
return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TotalFeatureImportance createTestInstance() {
|
||||||
|
return randomInstance();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return TotalFeatureImportance.fromXContent(parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -22,26 +22,28 @@ IDs, or the special wildcard `_all` to get all trained models.
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
include-tagged::{doc-tests-file}[{api}-request]
|
include-tagged::{doc-tests-file}[{api}-request]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
<1> Constructing a new GET request referencing an existing Trained Model
|
<1> Constructing a new GET request referencing an existing trained model
|
||||||
<2> Set the paging parameters
|
<2> Set the paging parameters
|
||||||
<3> Indicate if the complete model definition should be included
|
<3> Indicate if the complete model definition should be included
|
||||||
<4> Should the definition be fully decompressed on GET
|
<4> Indicate if the total feature importance for the features used in training
|
||||||
<5> Allow empty response if no Trained Models match the provided ID patterns.
|
should be included in the model `metadata` field.
|
||||||
If false, an error will be thrown if no Trained Models match the
|
<5> Should the definition be fully decompressed on GET
|
||||||
|
<6> Allow empty response if no trained models match the provided ID patterns.
|
||||||
|
If false, an error will be thrown if no trained models match the
|
||||||
ID patterns.
|
ID patterns.
|
||||||
<6> An optional list of tags used to narrow the model search. A Trained Model
|
<7> An optional list of tags used to narrow the model search. A trained model
|
||||||
can have many tags or none. The trained models in the response will
|
can have many tags or none. The trained models in the response will
|
||||||
contain all the provided tags.
|
contain all the provided tags.
|
||||||
<7> Optional boolean value indicating if certain fields should be removed on
|
<8> Optional boolean value for requesting the trained model in a format that can
|
||||||
retrieval. This is useful for getting the trained model in a format that
|
then be put into another cluster. Certain fields that can only be set when
|
||||||
can then be put into another cluster.
|
the model is imported are removed.
|
||||||
|
|
||||||
include::../execution.asciidoc[]
|
include::../execution.asciidoc[]
|
||||||
|
|
||||||
[id="{upid}-{api}-response"]
|
[id="{upid}-{api}-response"]
|
||||||
==== Response
|
==== Response
|
||||||
|
|
||||||
The returned +{response}+ contains the requested Trained Model.
|
The returned +{response}+ contains the requested trained model.
|
||||||
|
|
||||||
["source","java",subs="attributes,callouts,macros"]
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
|
|
|
@ -29,18 +29,19 @@ experimental[]
|
||||||
[[ml-get-inference-prereq]]
|
[[ml-get-inference-prereq]]
|
||||||
== {api-prereq-title}
|
== {api-prereq-title}
|
||||||
|
|
||||||
Required privileges which should be added to a custom role:
|
If the {es} {security-features} are enabled, you must have the following
|
||||||
|
privileges:
|
||||||
|
|
||||||
* cluster: `monitor_ml`
|
* cluster: `monitor_ml`
|
||||||
|
|
||||||
For more information, see <<security-privileges>> and
|
For more information, see <<security-privileges>> and
|
||||||
{ml-docs-setup-privileges}.
|
{ml-docs-setup-privileges}.
|
||||||
|
|
||||||
|
|
||||||
[[ml-get-inference-desc]]
|
[[ml-get-inference-desc]]
|
||||||
== {api-description-title}
|
== {api-description-title}
|
||||||
|
|
||||||
You can get information for multiple trained models in a single API request by
|
You can get information for multiple trained models in a single API request by
|
||||||
using a comma-separated list of model IDs or a wildcard expression.
|
using a comma-separated list of model IDs or a wildcard expression.
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,7 +49,7 @@ using a comma-separated list of model IDs or a wildcard expression.
|
||||||
== {api-path-parms-title}
|
== {api-path-parms-title}
|
||||||
|
|
||||||
`<model_id>`::
|
`<model_id>`::
|
||||||
(Optional, string)
|
(Optional, string)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,12 +57,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
|
||||||
== {api-query-parms-title}
|
== {api-query-parms-title}
|
||||||
|
|
||||||
`allow_no_match`::
|
`allow_no_match`::
|
||||||
(Optional, boolean)
|
(Optional, boolean)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models]
|
||||||
|
|
||||||
`decompress_definition`::
|
`decompress_definition`::
|
||||||
(Optional, boolean)
|
(Optional, boolean)
|
||||||
Specifies whether the included model definition should be returned as a JSON map
|
Specifies whether the included model definition should be returned as a JSON map
|
||||||
(`true`) or in a custom compressed format (`false`). Defaults to `true`.
|
(`true`) or in a custom compressed format (`false`). Defaults to `true`.
|
||||||
|
|
||||||
`for_export`::
|
`for_export`::
|
||||||
|
@ -71,17 +72,21 @@ retrieval. This allows the model to be in an acceptable format to be retrieved
|
||||||
and then added to another cluster. Default is false.
|
and then added to another cluster. Default is false.
|
||||||
|
|
||||||
`from`::
|
`from`::
|
||||||
(Optional, integer)
|
(Optional, integer)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models]
|
||||||
|
|
||||||
`include_model_definition`::
|
`include`::
|
||||||
(Optional, boolean)
|
(Optional, string)
|
||||||
Specifies whether the model definition is returned in the response. Defaults to
|
A comma delimited string of optional fields to include in the response body.
|
||||||
`false`. When `true`, only a single model must match the ID patterns provided.
|
Valid options are:
|
||||||
Otherwise, a bad request is returned.
|
- `definition`: Includes the model definition
|
||||||
|
- `total_feature_importance`: Includes the total feature importance for the
|
||||||
|
training data set. This field is available in the `metadata` field in the
|
||||||
|
response body.
|
||||||
|
Default is empty, indicating including no optional fields.
|
||||||
|
|
||||||
`size`::
|
`size`::
|
||||||
(Optional, integer)
|
(Optional, integer)
|
||||||
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models]
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models]
|
||||||
|
|
||||||
`tags`::
|
`tags`::
|
||||||
|
@ -94,7 +99,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tags]
|
||||||
|
|
||||||
`trained_model_configs`::
|
`trained_model_configs`::
|
||||||
(array)
|
(array)
|
||||||
An array of trained model resources, which are sorted by the `model_id` value in
|
An array of trained model resources, which are sorted by the `model_id` value in
|
||||||
ascending order.
|
ascending order.
|
||||||
+
|
+
|
||||||
.Properties of trained model resources
|
.Properties of trained model resources
|
||||||
|
@ -132,8 +137,86 @@ The license level of the trained model.
|
||||||
|
|
||||||
`metadata`:::
|
`metadata`:::
|
||||||
(object)
|
(object)
|
||||||
An object containing metadata about the trained model. For example, models
|
An object containing metadata about the trained model. For example, models
|
||||||
created by {dfanalytics} contain `analysis_config` and `input` objects.
|
created by {dfanalytics} contain `analysis_config` and `input` objects.
|
||||||
|
.Properties of metadata
|
||||||
|
[%collapsible%open]
|
||||||
|
=====
|
||||||
|
`total_feature_importance`:::
|
||||||
|
(array)
|
||||||
|
An array of the total feature importance for each feature used from
|
||||||
|
the training data set. This array of objects is returned if {dfanalytics} trained
|
||||||
|
the model and the request includes `total_feature_importance` in the `include`
|
||||||
|
request parameter.
|
||||||
|
+
|
||||||
|
.Properties of total feature importance
|
||||||
|
[%collapsible%open]
|
||||||
|
======
|
||||||
|
|
||||||
|
`feature_name`:::
|
||||||
|
(string)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name]
|
||||||
|
|
||||||
|
`importance`:::
|
||||||
|
(object)
|
||||||
|
A collection of feature importance statistics related to the training data set for this particular feature.
|
||||||
|
+
|
||||||
|
.Properties of feature importance
|
||||||
|
[%collapsible%open]
|
||||||
|
=======
|
||||||
|
`mean_magnitude`:::
|
||||||
|
(double)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
|
||||||
|
|
||||||
|
`max`:::
|
||||||
|
(int)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
|
||||||
|
|
||||||
|
`min`:::
|
||||||
|
(int)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
|
||||||
|
|
||||||
|
=======
|
||||||
|
|
||||||
|
`classes`:::
|
||||||
|
(array)
|
||||||
|
If the trained model is a classification model, feature importance statistics are gathered
|
||||||
|
per target class value.
|
||||||
|
+
|
||||||
|
.Properties of class feature importance
|
||||||
|
[%collapsible%open]
|
||||||
|
|
||||||
|
=======
|
||||||
|
|
||||||
|
`class_name`:::
|
||||||
|
(string)
|
||||||
|
The target class value. Could be a string, boolean, or number.
|
||||||
|
|
||||||
|
`importance`:::
|
||||||
|
(object)
|
||||||
|
A collection of feature importance statistics related to the training data set for this particular feature.
|
||||||
|
+
|
||||||
|
.Properties of feature importance
|
||||||
|
[%collapsible%open]
|
||||||
|
========
|
||||||
|
`mean_magnitude`:::
|
||||||
|
(double)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
|
||||||
|
|
||||||
|
`max`:::
|
||||||
|
(int)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
|
||||||
|
|
||||||
|
`min`:::
|
||||||
|
(int)
|
||||||
|
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
|
||||||
|
|
||||||
|
========
|
||||||
|
|
||||||
|
=======
|
||||||
|
|
||||||
|
======
|
||||||
|
=====
|
||||||
|
|
||||||
`model_id`:::
|
`model_id`:::
|
||||||
(string)
|
(string)
|
||||||
|
@ -152,13 +235,13 @@ The {es} version number in which the trained model was created.
|
||||||
== {api-response-codes-title}
|
== {api-response-codes-title}
|
||||||
|
|
||||||
`400`::
|
`400`::
|
||||||
If `include_model_definition` is `true`, this code indicates that more than
|
If `include_model_definition` is `true`, this code indicates that more than
|
||||||
one models match the ID pattern.
|
one models match the ID pattern.
|
||||||
|
|
||||||
`404` (Missing resources)::
|
`404` (Missing resources)::
|
||||||
If `allow_no_match` is `false`, this code indicates that there are no
|
If `allow_no_match` is `false`, this code indicates that there are no
|
||||||
resources that match the request or only partial matches for the request.
|
resources that match the request or only partial matches for the request.
|
||||||
|
|
||||||
|
|
||||||
[[ml-get-inference-example]]
|
[[ml-get-inference-example]]
|
||||||
== {api-examples-title}
|
== {api-examples-title}
|
||||||
|
|
|
@ -780,6 +780,23 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that
|
||||||
used to train the model, which defaults to `<dependent_variable>_prediction`.
|
used to train the model, which defaults to `<dependent_variable>_prediction`.
|
||||||
end::inference-config-results-field-processor[]
|
end::inference-config-results-field-processor[]
|
||||||
|
|
||||||
|
tag::inference-metadata-feature-importance-feature-name[]
|
||||||
|
The training feature name for which this importance was calculated.
|
||||||
|
end::inference-metadata-feature-importance-feature-name[]
|
||||||
|
tag::inference-metadata-feature-importance-magnitude[]
|
||||||
|
The average magnitude of this feature across all the training data.
|
||||||
|
This value is the average of the absolute values of the importance
|
||||||
|
for this feature.
|
||||||
|
end::inference-metadata-feature-importance-magnitude[]
|
||||||
|
tag::inference-metadata-feature-importance-max[]
|
||||||
|
The maximum importance value across all the training data for this
|
||||||
|
feature.
|
||||||
|
end::inference-metadata-feature-importance-max[]
|
||||||
|
tag::inference-metadata-feature-importance-min[]
|
||||||
|
The minimum importance value across all the training data for this
|
||||||
|
feature.
|
||||||
|
end::inference-metadata-feature-importance-min[]
|
||||||
|
|
||||||
tag::influencers[]
|
tag::influencers[]
|
||||||
A comma separated list of influencer field names. Typically these can be the by,
|
A comma separated list of influencer field names. Typically these can be the by,
|
||||||
over, or partition fields that are used in the detector configuration. You might
|
over, or partition fields that are used in the detector configuration. You might
|
||||||
|
|
|
@ -10,15 +10,19 @@ import org.elasticsearch.action.ActionType;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.io.stream.StreamInput;
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.common.util.set.Sets;
|
||||||
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
|
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
|
||||||
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
|
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
|
||||||
import org.elasticsearch.xpack.core.action.util.QueryPage;
|
import org.elasticsearch.xpack.core.action.util.QueryPage;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
|
||||||
public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
|
public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
|
||||||
|
@ -32,23 +36,60 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
||||||
|
|
||||||
public static class Request extends AbstractGetResourcesRequest {
|
public static class Request extends AbstractGetResourcesRequest {
|
||||||
|
|
||||||
public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
|
static final String DEFINITION = "definition";
|
||||||
|
static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
|
||||||
|
private static final Set<String> KNOWN_INCLUDES;
|
||||||
|
static {
|
||||||
|
HashSet<String> includes = new HashSet<>(2, 1.0f);
|
||||||
|
includes.add(DEFINITION);
|
||||||
|
includes.add(TOTAL_FEATURE_IMPORTANCE);
|
||||||
|
KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
|
||||||
|
}
|
||||||
|
public static final ParseField INCLUDE = new ParseField("include");
|
||||||
|
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
|
||||||
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
|
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
|
||||||
public static final ParseField TAGS = new ParseField("tags");
|
public static final ParseField TAGS = new ParseField("tags");
|
||||||
|
|
||||||
private final boolean includeModelDefinition;
|
private final Set<String> includes;
|
||||||
private final List<String> tags;
|
private final List<String> tags;
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public Request(String id, boolean includeModelDefinition, List<String> tags) {
|
public Request(String id, boolean includeModelDefinition, List<String> tags) {
|
||||||
setResourceId(id);
|
setResourceId(id);
|
||||||
setAllowNoResources(true);
|
setAllowNoResources(true);
|
||||||
this.includeModelDefinition = includeModelDefinition;
|
|
||||||
this.tags = tags == null ? Collections.emptyList() : tags;
|
this.tags = tags == null ? Collections.emptyList() : tags;
|
||||||
|
if (includeModelDefinition) {
|
||||||
|
this.includes = new HashSet<>(Collections.singletonList(DEFINITION));
|
||||||
|
} else {
|
||||||
|
this.includes = Collections.emptySet();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Request(String id, List<String> tags, Set<String> includes) {
|
||||||
|
setResourceId(id);
|
||||||
|
setAllowNoResources(true);
|
||||||
|
this.tags = tags == null ? Collections.emptyList() : tags;
|
||||||
|
this.includes = includes == null ? Collections.emptySet() : includes;
|
||||||
|
Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
|
||||||
|
if (unknownIncludes.isEmpty() == false) {
|
||||||
|
throw ExceptionsHelper.badRequestException(
|
||||||
|
"unknown [include] parameters {}. Valid options are {}",
|
||||||
|
unknownIncludes,
|
||||||
|
KNOWN_INCLUDES);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Request(StreamInput in) throws IOException {
|
public Request(StreamInput in) throws IOException {
|
||||||
super(in);
|
super(in);
|
||||||
this.includeModelDefinition = in.readBoolean();
|
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||||
|
this.includes = in.readSet(StreamInput::readString);
|
||||||
|
} else {
|
||||||
|
Set<String> includes = new HashSet<>();
|
||||||
|
if (in.readBoolean()) {
|
||||||
|
includes.add(DEFINITION);
|
||||||
|
}
|
||||||
|
this.includes = includes;
|
||||||
|
}
|
||||||
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||||
this.tags = in.readStringList();
|
this.tags = in.readStringList();
|
||||||
} else {
|
} else {
|
||||||
|
@ -62,7 +103,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isIncludeModelDefinition() {
|
public boolean isIncludeModelDefinition() {
|
||||||
return includeModelDefinition;
|
return this.includes.contains(DEFINITION);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isIncludeTotalFeatureImportance() {
|
||||||
|
return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<String> getTags() {
|
public List<String> getTags() {
|
||||||
|
@ -72,7 +117,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
||||||
@Override
|
@Override
|
||||||
public void writeTo(StreamOutput out) throws IOException {
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
super.writeTo(out);
|
super.writeTo(out);
|
||||||
out.writeBoolean(includeModelDefinition);
|
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
|
||||||
|
out.writeCollection(this.includes, StreamOutput::writeString);
|
||||||
|
} else {
|
||||||
|
out.writeBoolean(this.includes.contains(DEFINITION));
|
||||||
|
}
|
||||||
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
||||||
out.writeStringCollection(tags);
|
out.writeStringCollection(tags);
|
||||||
}
|
}
|
||||||
|
@ -80,7 +129,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(super.hashCode(), includeModelDefinition, tags);
|
return Objects.hash(super.hashCode(), includes, tags);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -92,7 +141,18 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
Request other = (Request) obj;
|
Request other = (Request) obj;
|
||||||
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
|
return super.equals(obj) && Objects.equals(includes, other.includes) && Objects.equals(tags, other.tags);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Request{" +
|
||||||
|
"includes=" + includes +
|
||||||
|
", tags=" + tags +
|
||||||
|
", page=" + getPageParams() +
|
||||||
|
", id=" + getResourceId() +
|
||||||
|
", allow_missing=" + isAllowNoResources() +
|
||||||
|
'}';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
import org.elasticsearch.xpack.core.ml.utils.MlStrings;
|
||||||
|
@ -39,6 +40,7 @@ import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.elasticsearch.action.ValidateActions.addValidationError;
|
import static org.elasticsearch.action.ValidateActions.addValidationError;
|
||||||
|
@ -51,6 +53,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
|
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
|
||||||
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
||||||
public static final String FOR_EXPORT = "for_export";
|
public static final String FOR_EXPORT = "for_export";
|
||||||
|
public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
|
||||||
|
private static final Set<String> RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE);
|
||||||
|
|
||||||
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
|
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
|
||||||
|
|
||||||
|
@ -419,7 +423,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
|
this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
|
||||||
this.description = config.getDescription();
|
this.description = config.getDescription();
|
||||||
this.tags = config.getTags();
|
this.tags = config.getTags();
|
||||||
this.metadata = config.getMetadata();
|
this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata());
|
||||||
this.input = config.getInput();
|
this.input = config.getInput();
|
||||||
this.estimatedOperations = config.estimatedOperations;
|
this.estimatedOperations = config.estimatedOperations;
|
||||||
this.estimatedHeapMemory = config.estimatedHeapMemory;
|
this.estimatedHeapMemory = config.estimatedHeapMemory;
|
||||||
|
@ -471,6 +475,18 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setFeatureImportance(List<TotalFeatureImportance> totalFeatureImportance) {
|
||||||
|
if (totalFeatureImportance == null) {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
if (this.metadata == null) {
|
||||||
|
this.metadata = new HashMap<>();
|
||||||
|
}
|
||||||
|
this.metadata.put(TOTAL_FEATURE_IMPORTANCE,
|
||||||
|
totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()));
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
|
public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
|
||||||
if (definition == null) {
|
if (definition == null) {
|
||||||
return this;
|
return this;
|
||||||
|
@ -627,6 +643,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
ESTIMATED_OPERATIONS.getPreferredName(),
|
ESTIMATED_OPERATIONS.getPreferredName(),
|
||||||
validationException);
|
validationException);
|
||||||
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
|
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
|
||||||
|
if (metadata != null) {
|
||||||
|
validationException = checkIllegalSetting(
|
||||||
|
metadata.get(TOTAL_FEATURE_IMPORTANCE),
|
||||||
|
METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE,
|
||||||
|
validationException);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (validationException != null) {
|
if (validationException != null) {
|
||||||
|
|
|
@ -20,8 +20,11 @@ import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
|
@ -81,16 +84,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
return builder.map(asMap());
|
||||||
builder.field(FEATURE_NAME.getPreferredName(), featureName);
|
|
||||||
if (importance != null) {
|
|
||||||
builder.field(IMPORTANCE.getPreferredName(), importance);
|
|
||||||
}
|
|
||||||
if (classImportances.isEmpty() == false) {
|
|
||||||
builder.field(CLASSES.getPreferredName(), classImportances);
|
|
||||||
}
|
|
||||||
builder.endObject();
|
|
||||||
return builder;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -103,6 +97,18 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
||||||
&& Objects.equals(classImportances, that.classImportances);
|
&& Objects.equals(classImportances, that.classImportances);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Map<String, Object> asMap() {
|
||||||
|
Map<String, Object> map = new LinkedHashMap<>();
|
||||||
|
map.put(FEATURE_NAME.getPreferredName(), featureName);
|
||||||
|
if (importance != null) {
|
||||||
|
map.put(IMPORTANCE.getPreferredName(), importance.asMap());
|
||||||
|
}
|
||||||
|
if (classImportances.isEmpty() == false) {
|
||||||
|
map.put(CLASSES.getPreferredName(), classImportances.stream().map(ClassImportance::asMap).collect(Collectors.toList()));
|
||||||
|
}
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(featureName, importance, classImportances);
|
return Objects.hash(featureName, importance, classImportances);
|
||||||
|
@ -165,12 +171,15 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
return builder.map(asMap());
|
||||||
builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
|
}
|
||||||
builder.field(MIN.getPreferredName(), min);
|
|
||||||
builder.field(MAX.getPreferredName(), max);
|
private Map<String, Object> asMap() {
|
||||||
builder.endObject();
|
Map<String, Object> map = new LinkedHashMap<>();
|
||||||
return builder;
|
map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
|
||||||
|
map.put(MIN.getPreferredName(), min);
|
||||||
|
map.put(MAX.getPreferredName(), max);
|
||||||
|
return map;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,11 +238,14 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
return builder.map(asMap());
|
||||||
builder.field(CLASS_NAME.getPreferredName(), className);
|
}
|
||||||
builder.field(IMPORTANCE.getPreferredName(), importance);
|
|
||||||
builder.endObject();
|
private Map<String, Object> asMap() {
|
||||||
return builder;
|
Map<String, Object> map = new LinkedHashMap<>();
|
||||||
|
map.put(CLASS_NAME.getPreferredName(), className);
|
||||||
|
map.put(IMPORTANCE.getPreferredName(), importance.asMap());
|
||||||
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -53,6 +53,10 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
|
||||||
return NAME + "-" + modelId;
|
return NAME + "-" + modelId;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String modelId(String docId) {
|
||||||
|
return docId.substring(NAME.length() + 1);
|
||||||
|
}
|
||||||
|
|
||||||
private final List<TotalFeatureImportance> totalFeatureImportances;
|
private final List<TotalFeatureImportance> totalFeatureImportances;
|
||||||
private final String modelId;
|
private final String modelId;
|
||||||
|
|
||||||
|
|
|
@ -103,7 +103,7 @@ public final class Messages {
|
||||||
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
|
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
|
||||||
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
|
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
|
||||||
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
|
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
|
||||||
public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]";
|
public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}";
|
||||||
public static final String INFERENCE_CANNOT_DELETE_MODEL =
|
public static final String INFERENCE_CANNOT_DELETE_MODEL =
|
||||||
"Unable to delete model [{0}]";
|
"Unable to delete model [{0}]";
|
||||||
public static final String MODEL_DEFINITION_TRUNCATED =
|
public static final String MODEL_DEFINITION_TRUNCATED =
|
||||||
|
|
|
@ -5,19 +5,28 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.core.ml.action;
|
package org.elasticsearch.xpack.core.ml.action;
|
||||||
|
|
||||||
|
import org.elasticsearch.Version;
|
||||||
import org.elasticsearch.common.io.stream.Writeable;
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
|
||||||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||||
|
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
|
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
|
||||||
|
|
||||||
public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Request createTestInstance() {
|
protected Request createTestInstance() {
|
||||||
Request request = new Request(randomAlphaOfLength(20),
|
Request request = new Request(randomAlphaOfLength(20),
|
||||||
randomBoolean(),
|
|
||||||
randomBoolean() ? null :
|
randomBoolean() ? null :
|
||||||
randomList(10, () -> randomAlphaOfLength(10)));
|
randomList(10, () -> randomAlphaOfLength(10)),
|
||||||
|
randomBoolean() ? null :
|
||||||
|
Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE))
|
||||||
|
.limit(4)
|
||||||
|
.collect(Collectors.toSet()));
|
||||||
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
|
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
|
||||||
return request;
|
return request;
|
||||||
}
|
}
|
||||||
|
@ -26,4 +35,22 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
|
||||||
protected Writeable.Reader<Request> instanceReader() {
|
protected Writeable.Reader<Request> instanceReader() {
|
||||||
return Request::new;
|
return Request::new;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Request mutateInstanceForVersion(Request instance, Version version) {
|
||||||
|
if (version.before(Version.V_7_10_0)) {
|
||||||
|
Set<String> includes = new HashSet<>();
|
||||||
|
if (instance.isIncludeModelDefinition()) {
|
||||||
|
includes.add(Request.DEFINITION);
|
||||||
|
}
|
||||||
|
Request request = new Request(
|
||||||
|
instance.getResourceId(),
|
||||||
|
version.before(Version.V_7_7_0) ? null : instance.getTags(),
|
||||||
|
includes);
|
||||||
|
request.setPageParams(instance.getPageParams());
|
||||||
|
request.setAllowNoResources(instance.isAllowNoResources());
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,11 +42,13 @@ import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.hasKey;
|
||||||
import static org.hamcrest.Matchers.startsWith;
|
import static org.hamcrest.Matchers.startsWith;
|
||||||
|
|
||||||
public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
|
public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
|
||||||
|
@ -95,19 +97,21 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
|
||||||
trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
|
trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
|
||||||
Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
|
Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
|
||||||
assertThat(ids.v1(), equalTo(1L));
|
assertThat(ids.v1(), equalTo(1L));
|
||||||
|
String inferenceModelId = ids.v2().iterator().next();
|
||||||
|
|
||||||
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
|
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
|
||||||
trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);
|
trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture);
|
||||||
|
|
||||||
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
|
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
|
||||||
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
|
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
|
||||||
assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
|
assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
|
||||||
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
|
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
|
||||||
|
assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
|
||||||
|
|
||||||
PlainActionFuture<TrainedModelMetadata> getTrainedMetadataFuture = new PlainActionFuture<>();
|
PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
|
||||||
trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture);
|
trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
|
||||||
|
|
||||||
TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet();
|
TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
|
||||||
assertThat(storedMetadata.getModelId(), startsWith(modelId));
|
assertThat(storedMetadata.getModelId(), startsWith(modelId));
|
||||||
assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances()));
|
assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,7 +90,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||||
|
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
blockingCall(
|
||||||
|
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||||
|
getConfigHolder,
|
||||||
|
exceptionHolder);
|
||||||
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
|
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
|
||||||
assertThat(getConfigHolder.get(), is(not(nullValue())));
|
assertThat(getConfigHolder.get(), is(not(nullValue())));
|
||||||
assertThat(getConfigHolder.get(), equalTo(config));
|
assertThat(getConfigHolder.get(), equalTo(config));
|
||||||
|
@ -121,7 +124,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
assertThat(exceptionHolder.get(), is(nullValue()));
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||||
|
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder);
|
blockingCall(listener ->
|
||||||
|
trainedModelProvider.getTrainedModel(modelId, false, false, listener),
|
||||||
|
getConfigHolder,
|
||||||
|
exceptionHolder);
|
||||||
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
|
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
|
||||||
assertThat(getConfigHolder.get(), is(not(nullValue())));
|
assertThat(getConfigHolder.get(), is(not(nullValue())));
|
||||||
assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition));
|
assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition));
|
||||||
|
@ -132,7 +138,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
String modelId = "test-get-missing-trained-model-config";
|
String modelId = "test-get-missing-trained-model-config";
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
blockingCall(
|
||||||
|
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||||
|
getConfigHolder,
|
||||||
|
exceptionHolder);
|
||||||
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||||
assertThat(exceptionHolder.get().getMessage(),
|
assertThat(exceptionHolder.get().getMessage(),
|
||||||
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||||
|
@ -154,7 +163,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
.actionGet();
|
.actionGet();
|
||||||
|
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
blockingCall(
|
||||||
|
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||||
|
getConfigHolder,
|
||||||
|
exceptionHolder);
|
||||||
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||||
assertThat(exceptionHolder.get().getMessage(),
|
assertThat(exceptionHolder.get().getMessage(),
|
||||||
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||||
|
@ -193,7 +205,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
blockingCall(
|
||||||
|
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||||
|
getConfigHolder,
|
||||||
|
exceptionHolder);
|
||||||
assertThat(getConfigHolder.get(), is(nullValue()));
|
assertThat(getConfigHolder.get(), is(nullValue()));
|
||||||
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||||
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
||||||
|
@ -238,7 +253,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
blockingCall(
|
||||||
|
listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
|
||||||
|
getConfigHolder,
|
||||||
|
exceptionHolder);
|
||||||
assertThat(getConfigHolder.get(), is(nullValue()));
|
assertThat(getConfigHolder.get(), is(nullValue()));
|
||||||
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||||
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
||||||
|
|
|
@ -57,15 +57,25 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
|
||||||
}
|
}
|
||||||
|
|
||||||
if (request.isIncludeModelDefinition()) {
|
if (request.isIncludeModelDefinition()) {
|
||||||
provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap(
|
provider.getTrainedModel(
|
||||||
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
|
totalAndIds.v2().iterator().next(),
|
||||||
listener::onFailure
|
true,
|
||||||
));
|
request.isIncludeTotalFeatureImportance(),
|
||||||
|
ActionListener.wrap(
|
||||||
|
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
|
||||||
|
listener::onFailure
|
||||||
|
)
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap(
|
provider.getTrainedModels(
|
||||||
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
|
totalAndIds.v2(),
|
||||||
listener::onFailure
|
request.isAllowNoResources(),
|
||||||
));
|
request.isIncludeTotalFeatureImportance(),
|
||||||
|
ActionListener.wrap(
|
||||||
|
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
|
||||||
|
listener::onFailure
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
listener::onFailure
|
listener::onFailure
|
||||||
|
|
|
@ -82,7 +82,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
||||||
responseBuilder.setLicensed(true);
|
responseBuilder.setLicensed(true);
|
||||||
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
|
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
|
||||||
} else {
|
} else {
|
||||||
trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
|
trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap(
|
||||||
trainedModelConfig -> {
|
trainedModelConfig -> {
|
||||||
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
|
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
|
||||||
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
|
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
|
||||||
|
|
|
@ -270,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void loadModel(String modelId, Consumer consumer) {
|
private void loadModel(String modelId, Consumer consumer) {
|
||||||
provider.getTrainedModel(modelId, false, ActionListener.wrap(
|
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
|
||||||
trainedModelConfig -> {
|
trainedModelConfig -> {
|
||||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||||
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
|
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
|
||||||
|
@ -306,7 +306,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
||||||
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
|
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
|
||||||
// by a simulated pipeline
|
// by a simulated pipeline
|
||||||
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
|
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
|
||||||
provider.getTrainedModel(modelId, false, ActionListener.wrap(
|
provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
|
||||||
trainedModelConfig -> {
|
trainedModelConfig -> {
|
||||||
// Verify we can pull the model into memory without causing OOM
|
// Verify we can pull the model into memory without causing OOM
|
||||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||||
|
@ -434,7 +434,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
||||||
|
|
||||||
logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
|
logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
|
||||||
notification.getValue().model.getModelId()));
|
notification.getValue().model.getModelId()));
|
||||||
|
|
||||||
// If the model is no longer referenced, flush the stats to persist as soon as possible
|
// If the model is no longer referenced, flush the stats to persist as soon as possible
|
||||||
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
|
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
|
||||||
} finally {
|
} finally {
|
||||||
|
|
|
@ -89,9 +89,11 @@ import java.util.Arrays;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.TreeSet;
|
import java.util.TreeSet;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
@ -235,14 +237,14 @@ public class TrainedModelProvider {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void getTrainedModelMetadata(String modelId, ActionListener<TrainedModelMetadata> listener) {
|
public void getTrainedModelMetadata(Collection<String> modelIds, ActionListener<Map<String, TrainedModelMetadata>> listener) {
|
||||||
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||||
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
|
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
|
||||||
.boolQuery()
|
.boolQuery()
|
||||||
.filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
|
.filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
|
||||||
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
|
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
|
||||||
TrainedModelMetadata.NAME))))
|
TrainedModelMetadata.NAME))))
|
||||||
.setSize(1)
|
.setSize(10_000)
|
||||||
// First find the latest index
|
// First find the latest index
|
||||||
.addSort("_index", SortOrder.DESC)
|
.addSort("_index", SortOrder.DESC)
|
||||||
.request();
|
.request();
|
||||||
|
@ -250,18 +252,20 @@ public class TrainedModelProvider {
|
||||||
searchResponse -> {
|
searchResponse -> {
|
||||||
if (searchResponse.getHits().getHits().length == 0) {
|
if (searchResponse.getHits().getHits().length == 0) {
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
listener.onFailure(new ResourceNotFoundException(
|
||||||
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
|
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<TrainedModelMetadata> metadataList = handleHits(searchResponse.getHits().getHits(),
|
HashMap<String, TrainedModelMetadata> map = new HashMap<>();
|
||||||
modelId,
|
for (SearchHit hit : searchResponse.getHits().getHits()) {
|
||||||
this::parseMetadataLenientlyFromSource);
|
String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId()));
|
||||||
listener.onResponse(metadataList.get(0));
|
map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId));
|
||||||
|
}
|
||||||
|
listener.onResponse(map);
|
||||||
},
|
},
|
||||||
e -> {
|
e -> {
|
||||||
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
listener.onFailure(new ResourceNotFoundException(
|
||||||
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
|
Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
listener.onFailure(e);
|
listener.onFailure(e);
|
||||||
|
@ -371,7 +375,7 @@ public class TrainedModelProvider {
|
||||||
// TODO Change this when we get more than just langIdent stored
|
// TODO Change this when we get more than just langIdent stored
|
||||||
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
||||||
try {
|
try {
|
||||||
TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
|
TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
|
||||||
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
|
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
|
||||||
listener.onResponse(
|
listener.onResponse(
|
||||||
InferenceDefinition.builder()
|
InferenceDefinition.builder()
|
||||||
|
@ -434,18 +438,50 @@ public class TrainedModelProvider {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
|
public void getTrainedModel(final String modelId,
|
||||||
|
final boolean includeDefinition,
|
||||||
|
final boolean includeTotalFeatureImportance,
|
||||||
|
final ActionListener<TrainedModelConfig> finalListener) {
|
||||||
|
|
||||||
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
||||||
try {
|
try {
|
||||||
listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
|
finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
|
||||||
return;
|
return;
|
||||||
} catch (ElasticsearchException ex) {
|
} catch (ElasticsearchException ex) {
|
||||||
listener.onFailure(ex);
|
finalListener.onFailure(ex);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
|
||||||
|
modelBuilder -> {
|
||||||
|
if (includeTotalFeatureImportance == false) {
|
||||||
|
finalListener.onResponse(modelBuilder.build());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap(
|
||||||
|
metadata -> {
|
||||||
|
TrainedModelMetadata modelMetadata = metadata.get(modelId);
|
||||||
|
if (modelMetadata != null) {
|
||||||
|
modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
|
||||||
|
}
|
||||||
|
finalListener.onResponse(modelBuilder.build());
|
||||||
|
},
|
||||||
|
failure -> {
|
||||||
|
// total feature importance is not necessary for a model to be valid
|
||||||
|
// we shouldn't fail if it is not found
|
||||||
|
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
|
||||||
|
finalListener.onResponse(modelBuilder.build());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
finalListener.onFailure(failure);
|
||||||
|
}
|
||||||
|
));
|
||||||
|
|
||||||
|
},
|
||||||
|
finalListener::onFailure
|
||||||
|
);
|
||||||
|
|
||||||
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
||||||
.idsQuery()
|
.idsQuery()
|
||||||
.addIds(modelId));
|
.addIds(modelId));
|
||||||
|
@ -483,11 +519,11 @@ public class TrainedModelProvider {
|
||||||
try {
|
try {
|
||||||
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
|
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
|
||||||
} catch (ResourceNotFoundException ex) {
|
} catch (ResourceNotFoundException ex) {
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
getTrainedModelListener.onFailure(new ResourceNotFoundException(
|
||||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||||
return;
|
return;
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
listener.onFailure(ex);
|
getTrainedModelListener.onFailure(ex);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -500,22 +536,22 @@ public class TrainedModelProvider {
|
||||||
String compressedString = getDefinitionFromDocs(docs, modelId);
|
String compressedString = getDefinitionFromDocs(docs, modelId);
|
||||||
builder.setDefinitionFromString(compressedString);
|
builder.setDefinitionFromString(compressedString);
|
||||||
} catch (ElasticsearchException elasticsearchException) {
|
} catch (ElasticsearchException elasticsearchException) {
|
||||||
listener.onFailure(elasticsearchException);
|
getTrainedModelListener.onFailure(elasticsearchException);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
} catch (ResourceNotFoundException ex) {
|
} catch (ResourceNotFoundException ex) {
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
getTrainedModelListener.onFailure(new ResourceNotFoundException(
|
||||||
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||||
return;
|
return;
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
listener.onFailure(ex);
|
getTrainedModelListener.onFailure(ex);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
listener.onResponse(builder.build());
|
getTrainedModelListener.onResponse(builder);
|
||||||
},
|
},
|
||||||
listener::onFailure
|
getTrainedModelListener::onFailure
|
||||||
);
|
);
|
||||||
|
|
||||||
executeAsyncWithOrigin(client,
|
executeAsyncWithOrigin(client,
|
||||||
|
@ -532,7 +568,10 @@ public class TrainedModelProvider {
|
||||||
* This does no expansion on the ids.
|
* This does no expansion on the ids.
|
||||||
* It assumes that there are fewer than 10k.
|
* It assumes that there are fewer than 10k.
|
||||||
*/
|
*/
|
||||||
public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) {
|
public void getTrainedModels(Set<String> modelIds,
|
||||||
|
boolean allowNoResources,
|
||||||
|
boolean includeTotalFeatureImportance,
|
||||||
|
final ActionListener<List<TrainedModelConfig>> finalListener) {
|
||||||
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
|
||||||
|
|
||||||
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||||
|
@ -541,23 +580,63 @@ public class TrainedModelProvider {
|
||||||
.setQuery(queryBuilder)
|
.setQuery(queryBuilder)
|
||||||
.setSize(modelIds.size())
|
.setSize(modelIds.size())
|
||||||
.request();
|
.request();
|
||||||
List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
|
List<TrainedModelConfig.Builder> configs = new ArrayList<>(modelIds.size());
|
||||||
Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
|
Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
|
||||||
Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
|
Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
|
||||||
for(String modelId : modelsAsResource) {
|
for(String modelId : modelsAsResource) {
|
||||||
try {
|
try {
|
||||||
configs.add(loadModelFromResource(modelId, true));
|
configs.add(loadModelFromResource(modelId, true));
|
||||||
} catch (ElasticsearchException ex) {
|
} catch (ElasticsearchException ex) {
|
||||||
listener.onFailure(ex);
|
finalListener.onFailure(ex);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (modelsInIndex.isEmpty()) {
|
if (modelsInIndex.isEmpty()) {
|
||||||
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
finalListener.onResponse(configs.stream()
|
||||||
listener.onResponse(configs);
|
.map(TrainedModelConfig.Builder::build)
|
||||||
|
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
||||||
|
.collect(Collectors.toList()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
|
||||||
|
modelBuilders -> {
|
||||||
|
if (includeTotalFeatureImportance == false) {
|
||||||
|
finalListener.onResponse(modelBuilders.stream()
|
||||||
|
.map(TrainedModelConfig.Builder::build)
|
||||||
|
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.getTrainedModelMetadata(modelIds, ActionListener.wrap(
|
||||||
|
metadata ->
|
||||||
|
finalListener.onResponse(modelBuilders.stream()
|
||||||
|
.map(builder -> {
|
||||||
|
TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
|
||||||
|
if (modelMetadata != null) {
|
||||||
|
builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
|
||||||
|
}
|
||||||
|
return builder.build();
|
||||||
|
})
|
||||||
|
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
||||||
|
.collect(Collectors.toList())),
|
||||||
|
failure -> {
|
||||||
|
// total feature importance is not necessary for a model to be valid
|
||||||
|
// we shouldn't fail if it is not found
|
||||||
|
if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
|
||||||
|
finalListener.onResponse(modelBuilders.stream()
|
||||||
|
.map(TrainedModelConfig.Builder::build)
|
||||||
|
.sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
finalListener.onFailure(failure);
|
||||||
|
}
|
||||||
|
));
|
||||||
|
},
|
||||||
|
finalListener::onFailure
|
||||||
|
);
|
||||||
|
|
||||||
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
|
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
|
||||||
searchResponse -> {
|
searchResponse -> {
|
||||||
Set<String> observedIds = new HashSet<>(
|
Set<String> observedIds = new HashSet<>(
|
||||||
|
@ -568,12 +647,12 @@ public class TrainedModelProvider {
|
||||||
try {
|
try {
|
||||||
if (observedIds.contains(searchHit.getId()) == false) {
|
if (observedIds.contains(searchHit.getId()) == false) {
|
||||||
configs.add(
|
configs.add(
|
||||||
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build()
|
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId())
|
||||||
);
|
);
|
||||||
observedIds.add(searchHit.getId());
|
observedIds.add(searchHit.getId());
|
||||||
}
|
}
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
listener.onFailure(
|
getTrainedModelListener.onFailure(
|
||||||
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
|
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -583,14 +662,13 @@ public class TrainedModelProvider {
|
||||||
// Otherwise, treat it as if it was never expanded to begin with.
|
// Otherwise, treat it as if it was never expanded to begin with.
|
||||||
Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
|
Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
|
||||||
if (missingConfigs.isEmpty() == false && allowNoResources == false) {
|
if (missingConfigs.isEmpty() == false && allowNoResources == false) {
|
||||||
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Ensure sorted even with the injection of locally resourced models
|
// Ensure sorted even with the injection of locally resourced models
|
||||||
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
getTrainedModelListener.onResponse(configs);
|
||||||
listener.onResponse(configs);
|
|
||||||
},
|
},
|
||||||
listener::onFailure
|
getTrainedModelListener::onFailure
|
||||||
);
|
);
|
||||||
|
|
||||||
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
|
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
|
||||||
|
@ -639,7 +717,7 @@ public class TrainedModelProvider {
|
||||||
foundResourceIds = new HashSet<>();
|
foundResourceIds = new HashSet<>();
|
||||||
for(String resourceId : matchedResourceIds) {
|
for(String resourceId : matchedResourceIds) {
|
||||||
// Does the model as a resource have all the tags?
|
// Does the model as a resource have all the tags?
|
||||||
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
|
if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) {
|
||||||
foundResourceIds.add(resourceId);
|
foundResourceIds.add(resourceId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -833,7 +911,7 @@ public class TrainedModelProvider {
|
||||||
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
|
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
|
||||||
}
|
}
|
||||||
|
|
||||||
TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
|
TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) {
|
||||||
URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
|
URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
|
||||||
if (resource == null) {
|
if (resource == null) {
|
||||||
logger.error("[{}] presumed stored as a resource but not found", modelId);
|
logger.error("[{}] presumed stored as a resource but not found", modelId);
|
||||||
|
@ -848,7 +926,7 @@ public class TrainedModelProvider {
|
||||||
if (nullOutDefinition) {
|
if (nullOutDefinition) {
|
||||||
builder.clearDefinition();
|
builder.clearDefinition();
|
||||||
}
|
}
|
||||||
return builder.build();
|
return builder;
|
||||||
} catch (IOException ioEx) {
|
} catch (IOException ioEx) {
|
||||||
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
|
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
|
||||||
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
|
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.elasticsearch.xpack.ml.MachineLearning;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
@ -56,12 +57,17 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
||||||
if (Strings.isNullOrEmpty(modelId)) {
|
if (Strings.isNullOrEmpty(modelId)) {
|
||||||
modelId = Metadata.ALL;
|
modelId = Metadata.ALL;
|
||||||
}
|
}
|
||||||
boolean includeModelDefinition = restRequest.paramAsBoolean(
|
|
||||||
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
|
|
||||||
false
|
|
||||||
);
|
|
||||||
List<String> tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
|
List<String> tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
|
||||||
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
|
Set<String> includes = new HashSet<>(
|
||||||
|
asList(
|
||||||
|
restRequest.paramAsStringArray(
|
||||||
|
GetTrainedModelsAction.Request.INCLUDE.getPreferredName(),
|
||||||
|
Strings.EMPTY_ARRAY)));
|
||||||
|
final GetTrainedModelsAction.Request request = restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION) ?
|
||||||
|
new GetTrainedModelsAction.Request(modelId,
|
||||||
|
restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false),
|
||||||
|
tags) :
|
||||||
|
new GetTrainedModelsAction.Request(modelId, tags, includes);
|
||||||
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
|
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
|
||||||
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
|
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
|
||||||
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
|
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
|
||||||
|
|
|
@ -437,9 +437,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
// the loading occurred or which models are currently in the cache due to evictions.
|
// the loading occurred or which models are currently in the cache due to evictions.
|
||||||
// Verify that we have at least loaded all three
|
// Verify that we have at least loaded all three
|
||||||
assertBusy(() -> {
|
assertBusy(() -> {
|
||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any());
|
||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any());
|
||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any());
|
||||||
});
|
});
|
||||||
assertBusy(() -> {
|
assertBusy(() -> {
|
||||||
assertThat(circuitBreaker.getUsed(), equalTo(10L));
|
assertThat(circuitBreaker.getUsed(), equalTo(10L));
|
||||||
|
@ -553,10 +553,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
|
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
|
||||||
doAnswer(invocationOnMock -> {
|
doAnswer(invocationOnMock -> {
|
||||||
@SuppressWarnings("rawtypes")
|
@SuppressWarnings("rawtypes")
|
||||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
|
||||||
listener.onResponse(trainedModelConfig);
|
listener.onResponse(trainedModelConfig);
|
||||||
return null;
|
return null;
|
||||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
|
@ -564,20 +564,20 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
doAnswer(invocationOnMock -> {
|
doAnswer(invocationOnMock -> {
|
||||||
@SuppressWarnings("rawtypes")
|
@SuppressWarnings("rawtypes")
|
||||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
listener.onFailure(new ResourceNotFoundException(
|
||||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||||
return null;
|
return null;
|
||||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
|
||||||
} else {
|
} else {
|
||||||
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
|
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
|
||||||
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
|
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
|
||||||
doAnswer(invocationOnMock -> {
|
doAnswer(invocationOnMock -> {
|
||||||
@SuppressWarnings("rawtypes")
|
@SuppressWarnings("rawtypes")
|
||||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
|
||||||
listener.onResponse(trainedModelConfig);
|
listener.onResponse(trainedModelConfig);
|
||||||
return null;
|
return null;
|
||||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
|
||||||
doAnswer(invocationOnMock -> {
|
doAnswer(invocationOnMock -> {
|
||||||
@SuppressWarnings("rawtypes")
|
@SuppressWarnings("rawtypes")
|
||||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
||||||
|
|
|
@ -57,14 +57,14 @@ public class TrainedModelProviderTests extends ESTestCase {
|
||||||
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||||
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
|
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
|
||||||
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
||||||
trainedModelProvider.getTrainedModel(modelId, true, future);
|
trainedModelProvider.getTrainedModel(modelId, true, false, future);
|
||||||
TrainedModelConfig configWithDefinition = future.actionGet();
|
TrainedModelConfig configWithDefinition = future.actionGet();
|
||||||
|
|
||||||
assertThat(configWithDefinition.getModelId(), equalTo(modelId));
|
assertThat(configWithDefinition.getModelId(), equalTo(modelId));
|
||||||
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
|
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
|
||||||
|
|
||||||
PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
|
PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
|
||||||
trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition);
|
trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition);
|
||||||
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
|
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
|
||||||
|
|
||||||
assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));
|
assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));
|
||||||
|
|
|
@ -33,7 +33,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
|
||||||
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||||
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
||||||
// Should be OK as we don't make any client calls
|
// Should be OK as we don't make any client calls
|
||||||
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future);
|
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future);
|
||||||
TrainedModelConfig config = future.actionGet();
|
TrainedModelConfig config = future.actionGet();
|
||||||
|
|
||||||
config.ensureParsedDefinition(xContentRegistry());
|
config.ensureParsedDefinition(xContentRegistry());
|
||||||
|
|
|
@ -34,11 +34,10 @@
|
||||||
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
|
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
|
||||||
"default":true
|
"default":true
|
||||||
},
|
},
|
||||||
"include_model_definition":{
|
"include":{
|
||||||
"type":"boolean",
|
"type":"string",
|
||||||
"required":false,
|
"required":false,
|
||||||
"description":"Should the full model definition be included in the results. These definitions can be large. So be cautious when including them. Defaults to false.",
|
"description":"A comma-separate list of fields to optionally include. Valid options are 'definition' and 'total_feature_importance'. Default is none."
|
||||||
"default":false
|
|
||||||
},
|
},
|
||||||
"decompress_definition":{
|
"decompress_definition":{
|
||||||
"type":"boolean",
|
"type":"boolean",
|
||||||
|
|
|
@ -1,6 +1,24 @@
|
||||||
setup:
|
setup:
|
||||||
- skip:
|
- skip:
|
||||||
features: headers
|
features:
|
||||||
|
- headers
|
||||||
|
- allowed_warnings
|
||||||
|
- do:
|
||||||
|
allowed_warnings:
|
||||||
|
- "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template"
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||||
|
index:
|
||||||
|
id: trained_model_metadata-a-regression-model-0
|
||||||
|
index: .ml-inference-000003
|
||||||
|
body:
|
||||||
|
model_id: "a-regression-model-0"
|
||||||
|
doc_type: "trained_model_metadata"
|
||||||
|
total_feature_importance:
|
||||||
|
- { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
|
||||||
|
- { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
headers:
|
headers:
|
||||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||||
|
@ -548,6 +566,20 @@ setup:
|
||||||
- match: { count: 12 }
|
- match: { count: 12 }
|
||||||
- match: { trained_model_configs.0.model_id: "a-regression-model-1" }
|
- match: { trained_model_configs.0.model_id: "a-regression-model-1" }
|
||||||
---
|
---
|
||||||
|
"Test get models with include total feature importance":
|
||||||
|
- do:
|
||||||
|
ml.get_trained_models:
|
||||||
|
model_id: "a-regression-model-*"
|
||||||
|
include: "total_feature_importance"
|
||||||
|
- match: { count: 2 }
|
||||||
|
- length: { trained_model_configs: 2 }
|
||||||
|
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
|
||||||
|
- is_true: trained_model_configs.0.metadata.total_feature_importance
|
||||||
|
- length: { trained_model_configs.0.metadata.total_feature_importance: 2 }
|
||||||
|
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
|
||||||
|
- is_false: trained_model_configs.1.metadata.total_feature_importance
|
||||||
|
|
||||||
|
---
|
||||||
"Test delete given unused trained model":
|
"Test delete given unused trained model":
|
||||||
- do:
|
- do:
|
||||||
ml.delete_trained_model:
|
ml.delete_trained_model:
|
||||||
|
@ -824,7 +856,7 @@ setup:
|
||||||
ml.get_trained_models:
|
ml.get_trained_models:
|
||||||
model_id: "a-regression-model-1"
|
model_id: "a-regression-model-1"
|
||||||
for_export: true
|
for_export: true
|
||||||
include_model_definition: true
|
include: "definition"
|
||||||
decompress_definition: false
|
decompress_definition: false
|
||||||
|
|
||||||
- match: { trained_model_configs.0.description: "empty model for tests" }
|
- match: { trained_model_configs.0.description: "empty model for tests" }
|
||||||
|
|
Loading…
Reference in New Issue