* [ML] adds new for_export flag to GET _ml/inference API (#57351) Adds a new boolean flag, `for_export` to the `GET _ml/inference/<model_id>` API. This flag is useful for moving models between clusters.
This commit is contained in:
parent
7dbf5baf60
commit
35d5126cea
|
@ -770,6 +770,9 @@ final class MLRequestConverters {
|
||||||
if (getTrainedModelsRequest.getTags() != null) {
|
if (getTrainedModelsRequest.getTags() != null) {
|
||||||
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
|
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
|
||||||
}
|
}
|
||||||
|
if (getTrainedModelsRequest.getForExport() != null) {
|
||||||
|
params.putParam(GetTrainedModelsRequest.FOR_EXPORT, Boolean.toString(getTrainedModelsRequest.getForExport()));
|
||||||
|
}
|
||||||
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
|
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
|
||||||
request.addParameters(params.asMap());
|
request.addParameters(params.asMap());
|
||||||
return request;
|
return request;
|
||||||
|
|
|
@ -34,6 +34,7 @@ public class GetTrainedModelsRequest implements Validatable {
|
||||||
|
|
||||||
public static final String ALLOW_NO_MATCH = "allow_no_match";
|
public static final String ALLOW_NO_MATCH = "allow_no_match";
|
||||||
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
|
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
|
||||||
|
public static final String FOR_EXPORT = "for_export";
|
||||||
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
||||||
public static final String TAGS = "tags";
|
public static final String TAGS = "tags";
|
||||||
|
|
||||||
|
@ -41,6 +42,7 @@ public class GetTrainedModelsRequest implements Validatable {
|
||||||
private Boolean allowNoMatch;
|
private Boolean allowNoMatch;
|
||||||
private Boolean includeDefinition;
|
private Boolean includeDefinition;
|
||||||
private Boolean decompressDefinition;
|
private Boolean decompressDefinition;
|
||||||
|
private Boolean forExport;
|
||||||
private PageParams pageParams;
|
private PageParams pageParams;
|
||||||
private List<String> tags;
|
private List<String> tags;
|
||||||
|
|
||||||
|
@ -137,6 +139,23 @@ public class GetTrainedModelsRequest implements Validatable {
|
||||||
return setTags(Arrays.asList(tags));
|
return setTags(Arrays.asList(tags));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Boolean getForExport() {
|
||||||
|
return forExport;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setting this flag to `true` removes certain fields from the model definition on retrieval.
|
||||||
|
*
|
||||||
|
* This is useful when getting the model and wanting to put it in another cluster.
|
||||||
|
*
|
||||||
|
* Default value is false.
|
||||||
|
* @param forExport Boolean value indicating if certain fields should be removed from the mode on GET
|
||||||
|
*/
|
||||||
|
public GetTrainedModelsRequest setForExport(Boolean forExport) {
|
||||||
|
this.forExport = forExport;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Optional<ValidationException> validate() {
|
public Optional<ValidationException> validate() {
|
||||||
if (ids == null || ids.isEmpty()) {
|
if (ids == null || ids.isEmpty()) {
|
||||||
|
@ -155,11 +174,12 @@ public class GetTrainedModelsRequest implements Validatable {
|
||||||
&& Objects.equals(allowNoMatch, other.allowNoMatch)
|
&& Objects.equals(allowNoMatch, other.allowNoMatch)
|
||||||
&& Objects.equals(decompressDefinition, other.decompressDefinition)
|
&& Objects.equals(decompressDefinition, other.decompressDefinition)
|
||||||
&& Objects.equals(includeDefinition, other.includeDefinition)
|
&& Objects.equals(includeDefinition, other.includeDefinition)
|
||||||
|
&& Objects.equals(forExport, other.forExport)
|
||||||
&& Objects.equals(pageParams, other.pageParams);
|
&& Objects.equals(pageParams, other.pageParams);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition);
|
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3611,7 +3611,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
||||||
.setIncludeDefinition(false) // <3>
|
.setIncludeDefinition(false) // <3>
|
||||||
.setDecompressDefinition(false) // <4>
|
.setDecompressDefinition(false) // <4>
|
||||||
.setAllowNoMatch(true) // <5>
|
.setAllowNoMatch(true) // <5>
|
||||||
.setTags("regression"); // <6>
|
.setTags("regression") // <6>
|
||||||
|
.setForExport(false); // <7>
|
||||||
// end::get-trained-models-request
|
// end::get-trained-models-request
|
||||||
request.setTags((List<String>)null);
|
request.setTags((List<String>)null);
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,9 @@ include-tagged::{doc-tests-file}[{api}-request]
|
||||||
<6> An optional list of tags used to narrow the model search. A Trained Model
|
<6> An optional list of tags used to narrow the model search. A Trained Model
|
||||||
can have many tags or none. The trained models in the response will
|
can have many tags or none. The trained models in the response will
|
||||||
contain all the provided tags.
|
contain all the provided tags.
|
||||||
|
<7> Optional boolean value indicating if certain fields should be removed on
|
||||||
|
retrieval. This is useful for getting the trained model in a format that
|
||||||
|
can then be put into another cluster.
|
||||||
|
|
||||||
include::../execution.asciidoc[]
|
include::../execution.asciidoc[]
|
||||||
|
|
||||||
|
|
|
@ -81,6 +81,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=size]
|
||||||
(Optional, string)
|
(Optional, string)
|
||||||
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
|
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
|
||||||
|
|
||||||
|
`for_export`::
|
||||||
|
(Optional, boolean)
|
||||||
|
Indicates if certain fields should be removed from the model configuration on
|
||||||
|
retrieval. This allows the model to be in an acceptable format to be retrieved
|
||||||
|
and then added to another cluster. Default is false.
|
||||||
|
|
||||||
[role="child_attributes"]
|
[role="child_attributes"]
|
||||||
[[ml-get-inference-results]]
|
[[ml-get-inference-results]]
|
||||||
==== {api-response-body-title}
|
==== {api-response-body-title}
|
||||||
|
|
|
@ -49,6 +49,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
public static final String NAME = "trained_model_config";
|
public static final String NAME = "trained_model_config";
|
||||||
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
|
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
|
||||||
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
||||||
|
public static final String FOR_EXPORT = "for_export";
|
||||||
|
|
||||||
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
|
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
|
||||||
|
|
||||||
|
@ -304,13 +305,22 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
|
// If the model is to be exported for future import to another cluster, these fields are irrelevant.
|
||||||
|
if (params.paramAsBoolean(FOR_EXPORT, false) == false) {
|
||||||
builder.field(MODEL_ID.getPreferredName(), modelId);
|
builder.field(MODEL_ID.getPreferredName(), modelId);
|
||||||
builder.field(CREATED_BY.getPreferredName(), createdBy);
|
builder.field(CREATED_BY.getPreferredName(), createdBy);
|
||||||
builder.field(VERSION.getPreferredName(), version.toString());
|
builder.field(VERSION.getPreferredName(), version.toString());
|
||||||
|
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
||||||
|
builder.humanReadableField(
|
||||||
|
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
|
||||||
|
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
|
||||||
|
new ByteSizeValue(estimatedHeapMemory));
|
||||||
|
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
|
||||||
|
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
|
||||||
|
}
|
||||||
if (description != null) {
|
if (description != null) {
|
||||||
builder.field(DESCRIPTION.getPreferredName(), description);
|
builder.field(DESCRIPTION.getPreferredName(), description);
|
||||||
}
|
}
|
||||||
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
|
||||||
// We don't store the definition in the same document as the configuration
|
// We don't store the definition in the same document as the configuration
|
||||||
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
|
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
|
||||||
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
|
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
|
||||||
|
@ -327,12 +337,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||||
}
|
}
|
||||||
builder.field(INPUT.getPreferredName(), input);
|
builder.field(INPUT.getPreferredName(), input);
|
||||||
builder.humanReadableField(
|
|
||||||
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
|
|
||||||
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
|
|
||||||
new ByteSizeValue(estimatedHeapMemory));
|
|
||||||
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
|
|
||||||
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
|
|
||||||
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
|
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
|
||||||
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
|
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,6 +40,7 @@ import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
@ -187,6 +188,43 @@ public class TrainedModelIT extends ESRestTestCase {
|
||||||
assertThat(response, containsString("\"definition\""));
|
assertThat(response, containsString("\"definition\""));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void testExportImportModel() throws IOException {
|
||||||
|
String modelId = "regression_model_to_export";
|
||||||
|
putRegressionModel(modelId);
|
||||||
|
Response getModel = client().performRequest(new Request("GET",
|
||||||
|
MachineLearning.BASE_PATH + "inference/" + modelId));
|
||||||
|
|
||||||
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
String response = EntityUtils.toString(getModel.getEntity());
|
||||||
|
assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
|
||||||
|
assertThat(response, containsString("\"count\":1"));
|
||||||
|
|
||||||
|
getModel = client().performRequest(new Request("GET",
|
||||||
|
MachineLearning.BASE_PATH +
|
||||||
|
"inference/" + modelId +
|
||||||
|
"?include_model_definition=true&decompress_definition=false&for_export=true"));
|
||||||
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
|
||||||
|
Map<String, Object> exportedModel = entityAsMap(getModel);
|
||||||
|
Map<String, Object> modelDefinition = ((List<Map<String, Object>>)exportedModel.get("trained_model_configs")).get(0);
|
||||||
|
|
||||||
|
String importedModelId = "regression_model_to_import";
|
||||||
|
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||||
|
builder.map(modelDefinition);
|
||||||
|
Request model = new Request("PUT", "_ml/inference/" + importedModelId);
|
||||||
|
model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
|
||||||
|
assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
}
|
||||||
|
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/regression*"));
|
||||||
|
|
||||||
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
response = EntityUtils.toString(getModel.getEntity());
|
||||||
|
assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
|
||||||
|
assertThat(response, containsString("\"model_id\":\"regression_model_to_import\""));
|
||||||
|
assertThat(response, containsString("\"count\":2"));
|
||||||
|
}
|
||||||
|
|
||||||
private void putRegressionModel(String modelId) throws IOException {
|
private void putRegressionModel(String modelId) throws IOException {
|
||||||
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||||
TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
|
TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
|
||||||
|
|
|
@ -74,7 +74,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Set<String> responseParams() {
|
protected Set<String> responseParams() {
|
||||||
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
|
return org.elasticsearch.common.collect.Set.of(TrainedModelConfig.DECOMPRESS_DEFINITION, TrainedModelConfig.FOR_EXPORT);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
|
private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
|
||||||
|
|
|
@ -62,6 +62,12 @@
|
||||||
"required":false,
|
"required":false,
|
||||||
"type":"list",
|
"type":"list",
|
||||||
"description":"A comma-separated list of tags that the model must have."
|
"description":"A comma-separated list of tags that the model must have."
|
||||||
|
},
|
||||||
|
"for_export": {
|
||||||
|
"required": false,
|
||||||
|
"type": "boolean",
|
||||||
|
"default": false,
|
||||||
|
"description": "Omits fields that are illegal to set on model PUT"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -818,3 +818,24 @@ setup:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
---
|
||||||
|
"Test for_export flag":
|
||||||
|
- do:
|
||||||
|
ml.get_trained_models:
|
||||||
|
model_id: "a-regression-model-1"
|
||||||
|
for_export: true
|
||||||
|
include_model_definition: true
|
||||||
|
decompress_definition: false
|
||||||
|
|
||||||
|
- match: { trained_model_configs.0.description: "empty model for tests" }
|
||||||
|
- is_true: trained_model_configs.0.compressed_definition
|
||||||
|
- is_true: trained_model_configs.0.input
|
||||||
|
- is_true: trained_model_configs.0.inference_config
|
||||||
|
- is_true: trained_model_configs.0.tags
|
||||||
|
- is_false: trained_model_configs.0.model_id
|
||||||
|
- is_false: trained_model_configs.0.created_by
|
||||||
|
- is_false: trained_model_configs.0.version
|
||||||
|
- is_false: trained_model_configs.0.create_time
|
||||||
|
- is_false: trained_model_configs.0.estimated_heap_memory_usage
|
||||||
|
- is_false: trained_model_configs.0.estimated_operations
|
||||||
|
- is_false: trained_model_configs.0.license_level
|
||||||
|
|
Loading…
Reference in New Issue