* [ML] adds new for_export flag to GET _ml/inference API (#57351) Adds a new boolean flag, `for_export` to the `GET _ml/inference/<model_id>` API. This flag is useful for moving models between clusters.
This commit is contained in:
parent
7dbf5baf60
commit
35d5126cea
|
@ -770,6 +770,9 @@ final class MLRequestConverters {
|
|||
if (getTrainedModelsRequest.getTags() != null) {
|
||||
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
|
||||
}
|
||||
if (getTrainedModelsRequest.getForExport() != null) {
|
||||
params.putParam(GetTrainedModelsRequest.FOR_EXPORT, Boolean.toString(getTrainedModelsRequest.getForExport()));
|
||||
}
|
||||
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
|
||||
request.addParameters(params.asMap());
|
||||
return request;
|
||||
|
|
|
@ -34,6 +34,7 @@ public class GetTrainedModelsRequest implements Validatable {
|
|||
|
||||
public static final String ALLOW_NO_MATCH = "allow_no_match";
|
||||
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
|
||||
public static final String FOR_EXPORT = "for_export";
|
||||
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
||||
public static final String TAGS = "tags";
|
||||
|
||||
|
@ -41,6 +42,7 @@ public class GetTrainedModelsRequest implements Validatable {
|
|||
private Boolean allowNoMatch;
|
||||
private Boolean includeDefinition;
|
||||
private Boolean decompressDefinition;
|
||||
private Boolean forExport;
|
||||
private PageParams pageParams;
|
||||
private List<String> tags;
|
||||
|
||||
|
@ -137,6 +139,23 @@ public class GetTrainedModelsRequest implements Validatable {
|
|||
return setTags(Arrays.asList(tags));
|
||||
}
|
||||
|
||||
public Boolean getForExport() {
|
||||
return forExport;
|
||||
}
|
||||
|
||||
/**
|
||||
* Setting this flag to `true` removes certain fields from the model definition on retrieval.
|
||||
*
|
||||
* This is useful when getting the model and wanting to put it in another cluster.
|
||||
*
|
||||
* Default value is false.
|
||||
* @param forExport Boolean value indicating if certain fields should be removed from the mode on GET
|
||||
*/
|
||||
public GetTrainedModelsRequest setForExport(Boolean forExport) {
|
||||
this.forExport = forExport;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ValidationException> validate() {
|
||||
if (ids == null || ids.isEmpty()) {
|
||||
|
@ -155,11 +174,12 @@ public class GetTrainedModelsRequest implements Validatable {
|
|||
&& Objects.equals(allowNoMatch, other.allowNoMatch)
|
||||
&& Objects.equals(decompressDefinition, other.decompressDefinition)
|
||||
&& Objects.equals(includeDefinition, other.includeDefinition)
|
||||
&& Objects.equals(forExport, other.forExport)
|
||||
&& Objects.equals(pageParams, other.pageParams);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition);
|
||||
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3611,7 +3611,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||
.setIncludeDefinition(false) // <3>
|
||||
.setDecompressDefinition(false) // <4>
|
||||
.setAllowNoMatch(true) // <5>
|
||||
.setTags("regression"); // <6>
|
||||
.setTags("regression") // <6>
|
||||
.setForExport(false); // <7>
|
||||
// end::get-trained-models-request
|
||||
request.setTags((List<String>)null);
|
||||
|
||||
|
|
|
@ -32,6 +32,9 @@ include-tagged::{doc-tests-file}[{api}-request]
|
|||
<6> An optional list of tags used to narrow the model search. A Trained Model
|
||||
can have many tags or none. The trained models in the response will
|
||||
contain all the provided tags.
|
||||
<7> Optional boolean value indicating if certain fields should be removed on
|
||||
retrieval. This is useful for getting the trained model in a format that
|
||||
can then be put into another cluster.
|
||||
|
||||
include::../execution.asciidoc[]
|
||||
|
||||
|
|
|
@ -81,6 +81,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=size]
|
|||
(Optional, string)
|
||||
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
|
||||
|
||||
`for_export`::
|
||||
(Optional, boolean)
|
||||
Indicates if certain fields should be removed from the model configuration on
|
||||
retrieval. This allows the model to be in an acceptable format to be retrieved
|
||||
and then added to another cluster. Default is false.
|
||||
|
||||
[role="child_attributes"]
|
||||
[[ml-get-inference-results]]
|
||||
==== {api-response-body-title}
|
||||
|
|
|
@ -49,6 +49,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
public static final String NAME = "trained_model_config";
|
||||
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
|
||||
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
|
||||
public static final String FOR_EXPORT = "for_export";
|
||||
|
||||
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
|
||||
|
||||
|
@ -304,13 +305,22 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
// If the model is to be exported for future import to another cluster, these fields are irrelevant.
|
||||
if (params.paramAsBoolean(FOR_EXPORT, false) == false) {
|
||||
builder.field(MODEL_ID.getPreferredName(), modelId);
|
||||
builder.field(CREATED_BY.getPreferredName(), createdBy);
|
||||
builder.field(VERSION.getPreferredName(), version.toString());
|
||||
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
||||
builder.humanReadableField(
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
|
||||
new ByteSizeValue(estimatedHeapMemory));
|
||||
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
|
||||
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
|
||||
}
|
||||
if (description != null) {
|
||||
builder.field(DESCRIPTION.getPreferredName(), description);
|
||||
}
|
||||
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
||||
// We don't store the definition in the same document as the configuration
|
||||
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
|
||||
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
|
||||
|
@ -327,12 +337,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
|||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||
}
|
||||
builder.field(INPUT.getPreferredName(), input);
|
||||
builder.humanReadableField(
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
|
||||
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
|
||||
new ByteSizeValue(estimatedHeapMemory));
|
||||
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
|
||||
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
|
||||
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
|
||||
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
|
||||
}
|
||||
|
|
|
@ -40,6 +40,7 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
|
@ -187,6 +188,43 @@ public class TrainedModelIT extends ESRestTestCase {
|
|||
assertThat(response, containsString("\"definition\""));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testExportImportModel() throws IOException {
|
||||
String modelId = "regression_model_to_export";
|
||||
putRegressionModel(modelId);
|
||||
Response getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH + "inference/" + modelId));
|
||||
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
String response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
|
||||
assertThat(response, containsString("\"count\":1"));
|
||||
|
||||
getModel = client().performRequest(new Request("GET",
|
||||
MachineLearning.BASE_PATH +
|
||||
"inference/" + modelId +
|
||||
"?include_model_definition=true&decompress_definition=false&for_export=true"));
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
|
||||
Map<String, Object> exportedModel = entityAsMap(getModel);
|
||||
Map<String, Object> modelDefinition = ((List<Map<String, Object>>)exportedModel.get("trained_model_configs")).get(0);
|
||||
|
||||
String importedModelId = "regression_model_to_import";
|
||||
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
builder.map(modelDefinition);
|
||||
Request model = new Request("PUT", "_ml/inference/" + importedModelId);
|
||||
model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
|
||||
assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
|
||||
}
|
||||
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/regression*"));
|
||||
|
||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||
response = EntityUtils.toString(getModel.getEntity());
|
||||
assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
|
||||
assertThat(response, containsString("\"model_id\":\"regression_model_to_import\""));
|
||||
assertThat(response, containsString("\"count\":2"));
|
||||
}
|
||||
|
||||
private void putRegressionModel(String modelId) throws IOException {
|
||||
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
|
||||
|
|
|
@ -74,7 +74,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
|||
|
||||
@Override
|
||||
protected Set<String> responseParams() {
|
||||
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
|
||||
return org.elasticsearch.common.collect.Set.of(TrainedModelConfig.DECOMPRESS_DEFINITION, TrainedModelConfig.FOR_EXPORT);
|
||||
}
|
||||
|
||||
private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
|
||||
|
|
|
@ -62,6 +62,12 @@
|
|||
"required":false,
|
||||
"type":"list",
|
||||
"description":"A comma-separated list of tags that the model must have."
|
||||
},
|
||||
"for_export": {
|
||||
"required": false,
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Omits fields that are illegal to set on model PUT"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -818,3 +818,24 @@ setup:
|
|||
}
|
||||
}
|
||||
}
|
||||
---
|
||||
"Test for_export flag":
|
||||
- do:
|
||||
ml.get_trained_models:
|
||||
model_id: "a-regression-model-1"
|
||||
for_export: true
|
||||
include_model_definition: true
|
||||
decompress_definition: false
|
||||
|
||||
- match: { trained_model_configs.0.description: "empty model for tests" }
|
||||
- is_true: trained_model_configs.0.compressed_definition
|
||||
- is_true: trained_model_configs.0.input
|
||||
- is_true: trained_model_configs.0.inference_config
|
||||
- is_true: trained_model_configs.0.tags
|
||||
- is_false: trained_model_configs.0.model_id
|
||||
- is_false: trained_model_configs.0.created_by
|
||||
- is_false: trained_model_configs.0.version
|
||||
- is_false: trained_model_configs.0.create_time
|
||||
- is_false: trained_model_configs.0.estimated_heap_memory_usage
|
||||
- is_false: trained_model_configs.0.estimated_operations
|
||||
- is_false: trained_model_configs.0.license_level
|
||||
|
|
Loading…
Reference in New Issue