[7.x] [ML] adds new for_export flag to GET _ml/inference API (#57351) (#57368)

* [ML] adds new for_export flag to GET _ml/inference API (#57351)

Adds a new boolean flag, `for_export`, to the `GET _ml/inference/<model_id>` API.

This flag is useful for moving models between clusters.
This commit is contained in:
Benjamin Trent 2020-05-29 14:01:08 -04:00 committed by GitHub
parent 7dbf5baf60
commit 35d5126cea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 115 additions and 13 deletions

View File

@ -770,6 +770,9 @@ final class MLRequestConverters {
if (getTrainedModelsRequest.getTags() != null) {
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
}
if (getTrainedModelsRequest.getForExport() != null) {
params.putParam(GetTrainedModelsRequest.FOR_EXPORT, Boolean.toString(getTrainedModelsRequest.getForExport()));
}
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.addParameters(params.asMap());
return request;

View File

@ -34,6 +34,7 @@ public class GetTrainedModelsRequest implements Validatable {
public static final String ALLOW_NO_MATCH = "allow_no_match";
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final String FOR_EXPORT = "for_export";
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String TAGS = "tags";
@ -41,6 +42,7 @@ public class GetTrainedModelsRequest implements Validatable {
private Boolean allowNoMatch;
private Boolean includeDefinition;
private Boolean decompressDefinition;
private Boolean forExport;
private PageParams pageParams;
private List<String> tags;
@ -137,6 +139,23 @@ public class GetTrainedModelsRequest implements Validatable {
return setTags(Arrays.asList(tags));
}
/**
 * Returns the {@code for_export} setting of this request.
 *
 * @return the configured flag, or {@code null} if it has not been set
 */
public Boolean getForExport() {
return forExport;
}
/**
 * Setting this flag to {@code true} removes certain fields from the model configuration on retrieval.
 *
 * This is useful when getting the model and wanting to put it in another cluster.
 *
 * Default value is false.
 * @param forExport Boolean value indicating if certain fields should be removed from the model on GET
 * @return this request, so that setter calls can be chained
 */
public GetTrainedModelsRequest setForExport(Boolean forExport) {
this.forExport = forExport;
return this;
}
@Override
public Optional<ValidationException> validate() {
if (ids == null || ids.isEmpty()) {
@ -155,11 +174,12 @@ public class GetTrainedModelsRequest implements Validatable {
&& Objects.equals(allowNoMatch, other.allowNoMatch)
&& Objects.equals(decompressDefinition, other.decompressDefinition)
&& Objects.equals(includeDefinition, other.includeDefinition)
&& Objects.equals(forExport, other.forExport)
&& Objects.equals(pageParams, other.pageParams);
}
@Override
public int hashCode() {
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition);
return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
}
}

View File

@ -3611,7 +3611,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setIncludeDefinition(false) // <3>
.setDecompressDefinition(false) // <4>
.setAllowNoMatch(true) // <5>
.setTags("regression"); // <6>
.setTags("regression") // <6>
.setForExport(false); // <7>
// end::get-trained-models-request
request.setTags((List<String>)null);

View File

@ -32,6 +32,9 @@ include-tagged::{doc-tests-file}[{api}-request]
<6> An optional list of tags used to narrow the model search. A Trained Model
can have many tags or none. The trained models in the response will
contain all the provided tags.
<7> Optional boolean value indicating if certain fields should be removed on
retrieval. This is useful for getting the trained model in a format that
can then be put into another cluster.
include::../execution.asciidoc[]

View File

@ -81,6 +81,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=size]
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
`for_export`::
(Optional, boolean)
Indicates whether certain fields should be removed from the model configuration
on retrieval. This returns the model in a format that can then be added to
another cluster. Default is false.
[role="child_attributes"]
[[ml-get-inference-results]]
==== {api-response-body-title}

View File

@ -49,6 +49,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_config";
public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String FOR_EXPORT = "for_export";
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
@ -304,13 +305,22 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
// If the model is to be exported for future import to another cluster, these fields are irrelevant.
if (params.paramAsBoolean(FOR_EXPORT, false) == false) {
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(CREATED_BY.getPreferredName(), createdBy);
builder.field(VERSION.getPreferredName(), version.toString());
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
builder.humanReadableField(
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
}
if (description != null) {
builder.field(DESCRIPTION.getPreferredName(), description);
}
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
// We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
@ -327,12 +337,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
}
builder.field(INPUT.getPreferredName(), input);
builder.humanReadableField(
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
}

View File

@ -40,6 +40,7 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.containsString;
@ -187,6 +188,43 @@ public class TrainedModelIT extends ESRestTestCase {
assertThat(response, containsString("\"definition\""));
}
/**
 * Round-trips a trained model through the {@code for_export} flag.
 *
 * Stores a regression model, fetches it with {@code for_export=true} (including its
 * compressed definition), PUTs the returned configuration unchanged under a new model ID,
 * and finally checks that both the original and the imported model are returned.
 *
 * @throws IOException if a REST request or the JSON conversion fails
 */
@SuppressWarnings("unchecked")
public void testExportImportModel() throws IOException {
    String modelId = "regression_model_to_export";
    putRegressionModel(modelId);

    // Sanity check: the freshly stored model is retrievable through the regular GET.
    Response getModel = client().performRequest(new Request("GET",
        MachineLearning.BASE_PATH + "inference/" + modelId));
    assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
    String response = EntityUtils.toString(getModel.getEntity());
    assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
    assertThat(response, containsString("\"count\":1"));

    // Fetch the model in export format, keeping the definition compressed.
    getModel = client().performRequest(new Request("GET",
        MachineLearning.BASE_PATH +
            "inference/" + modelId +
            "?include_model_definition=true&decompress_definition=false&for_export=true"));
    assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));

    Map<String, Object> exportedModel = entityAsMap(getModel);
    // The response wraps the configs in a list; exactly one model was requested.
    Map<String, Object> exportedConfig = ((List<Map<String, Object>>)exportedModel.get("trained_model_configs")).get(0);

    // Re-import the exported configuration as-is under a different model ID.
    String importedModelId = "regression_model_to_import";
    try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
        builder.map(exportedConfig);
        // Build the endpoint from BASE_PATH like every other request in this test.
        Request model = new Request("PUT", MachineLearning.BASE_PATH + "inference/" + importedModelId);
        model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
        assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
    }

    // Both the source model and its imported copy must now match the wildcard.
    getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/regression*"));
    assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
    response = EntityUtils.toString(getModel.getEntity());
    assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
    assertThat(response, containsString("\"model_id\":\"regression_model_to_import\""));
    assertThat(response, containsString("\"count\":2"));
}
private void putRegressionModel(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()

View File

@ -74,7 +74,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
@Override
protected Set<String> responseParams() {
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
return org.elasticsearch.common.collect.Set.of(TrainedModelConfig.DECOMPRESS_DEFINITION, TrainedModelConfig.FOR_EXPORT);
}
private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {

View File

@ -62,6 +62,12 @@
"required":false,
"type":"list",
"description":"A comma-separated list of tags that the model must have."
},
"for_export": {
"required": false,
"type": "boolean",
"default": false,
"description": "Omits fields that are illegal to set on model PUT"
}
}
}

View File

@ -818,3 +818,24 @@ setup:
}
}
}
---
"Test for_export flag":
# Fetch the model in export format, including its definition in compressed form.
- do:
ml.get_trained_models:
model_id: "a-regression-model-1"
for_export: true
include_model_definition: true
decompress_definition: false
# Fields that are still expected in the exported configuration.
- match: { trained_model_configs.0.description: "empty model for tests" }
- is_true: trained_model_configs.0.compressed_definition
- is_true: trained_model_configs.0.input
- is_true: trained_model_configs.0.inference_config
- is_true: trained_model_configs.0.tags
# Fields that must be omitted when for_export is true.
- is_false: trained_model_configs.0.model_id
- is_false: trained_model_configs.0.created_by
- is_false: trained_model_configs.0.version
- is_false: trained_model_configs.0.create_time
- is_false: trained_model_configs.0.estimated_heap_memory_usage
- is_false: trained_model_configs.0.estimated_operations
- is_false: trained_model_configs.0.license_level