diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index db64869c337..bf05815144b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -770,6 +770,9 @@ final class MLRequestConverters { if (getTrainedModelsRequest.getTags() != null) { params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags())); } + if (getTrainedModelsRequest.getForExport() != null) { + params.putParam(GetTrainedModelsRequest.FOR_EXPORT, Boolean.toString(getTrainedModelsRequest.getForExport())); + } Request request = new Request(HttpGet.METHOD_NAME, endpoint); request.addParameters(params.asMap()); return request; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java index d9aeb52d973..ca0284de84d 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java @@ -34,6 +34,7 @@ public class GetTrainedModelsRequest implements Validatable { public static final String ALLOW_NO_MATCH = "allow_no_match"; public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; + public static final String FOR_EXPORT = "for_export"; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; public static final String TAGS = "tags"; @@ -41,6 +42,7 @@ public class GetTrainedModelsRequest implements Validatable { private Boolean allowNoMatch; private Boolean includeDefinition; private Boolean decompressDefinition; + private Boolean forExport; private PageParams pageParams; private List tags; @@ -137,6 +139,23 @@ public class GetTrainedModelsRequest implements Validatable { return setTags(Arrays.asList(tags)); } + public Boolean getForExport() { + return forExport; + } + + /** + * Setting this flag to `true` removes certain fields from the model definition on retrieval. + * + * This is useful when getting the model and wanting to put it in another cluster. + * + * Default value is false. + * @param forExport Boolean value indicating if certain fields should be removed from the mode on GET + */ + public GetTrainedModelsRequest setForExport(Boolean forExport) { + this.forExport = forExport; + return this; + } + @Override public Optional validate() { if (ids == null || ids.isEmpty()) { @@ -155,11 +174,12 @@ public class GetTrainedModelsRequest implements Validatable { && Objects.equals(allowNoMatch, other.allowNoMatch) && Objects.equals(decompressDefinition, other.decompressDefinition) && Objects.equals(includeDefinition, other.includeDefinition) + && Objects.equals(forExport, other.forExport) && Objects.equals(pageParams, other.pageParams); } @Override public int hashCode() { - return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition); + return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 20f09b74679..f024a7fceaf 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3611,7 +3611,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setIncludeDefinition(false) // <3> .setDecompressDefinition(false) // <4> .setAllowNoMatch(true) // <5> - .setTags("regression"); // <6> + .setTags("regression") // <6> + .setForExport(false); // <7> // end::get-trained-models-request request.setTags((List)null); diff --git a/docs/java-rest/high-level/ml/get-trained-models.asciidoc b/docs/java-rest/high-level/ml/get-trained-models.asciidoc index 42cd060d881..9d5f964291d 100644 --- a/docs/java-rest/high-level/ml/get-trained-models.asciidoc +++ b/docs/java-rest/high-level/ml/get-trained-models.asciidoc @@ -32,6 +32,9 @@ include-tagged::{doc-tests-file}[{api}-request] <6> An optional list of tags used to narrow the model search. A Trained Model can have many tags or none. The trained models in the response will contain all the provided tags. +<7> Optional boolean value indicating if certain fields should be removed on + retrieval. This is useful for getting the trained model in a format that + can then be put into another cluster. include::../execution.asciidoc[] diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc index 52e86558807..4c3d1b732b9 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc @@ -81,6 +81,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=size] (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=tags] +`for_export`:: +(Optional, boolean) +Indicates if certain fields should be removed from the model configuration on +retrieval. This allows the model to be in an acceptable format to be retrieved +and then added to another cluster. Default is false. + [role="child_attributes"] [[ml-get-inference-results]] ==== {api-response-body-title} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index fbc694b7d03..49c2447c236 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -49,6 +49,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final String NAME = "trained_model_config"; public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; + public static final String FOR_EXPORT = "for_export"; private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; @@ -304,13 +305,22 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(MODEL_ID.getPreferredName(), modelId); - builder.field(CREATED_BY.getPreferredName(), createdBy); - builder.field(VERSION.getPreferredName(), version.toString()); + // If the model is to be exported for future import to another cluster, these fields are irrelevant. + if (params.paramAsBoolean(FOR_EXPORT, false) == false) { + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(CREATED_BY.getPreferredName(), createdBy); + builder.field(VERSION.getPreferredName(), version.toString()); + builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); + builder.humanReadableField( + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), + ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, + new ByteSizeValue(estimatedHeapMemory)); + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); + builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description()); + } if (description != null) { builder.field(DESCRIPTION.getPreferredName(), description); } - builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); // We don't store the definition in the same document as the configuration if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) { @@ -327,12 +337,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); } builder.field(INPUT.getPreferredName(), input); - builder.humanReadableField( - ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), - ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, - new ByteSizeValue(estimatedHeapMemory)); - builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); - builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description()); if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) { builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index db27ce4aeeb..0500424a289 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -40,6 +40,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; import static org.hamcrest.Matchers.containsString; @@ -187,6 +188,43 @@ public class TrainedModelIT extends ESRestTestCase { assertThat(response, containsString("\"definition\"")); } + @SuppressWarnings("unchecked") + public void testExportImportModel() throws IOException { + String modelId = "regression_model_to_export"; + putRegressionModel(modelId); + Response getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + "inference/" + modelId)); + + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + String response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"regression_model_to_export\"")); + assertThat(response, containsString("\"count\":1")); + + getModel = client().performRequest(new Request("GET", + MachineLearning.BASE_PATH + + "inference/" + modelId + + "?include_model_definition=true&decompress_definition=false&for_export=true")); + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + + Map exportedModel = entityAsMap(getModel); + Map modelDefinition = ((List>)exportedModel.get("trained_model_configs")).get(0); + + String importedModelId = "regression_model_to_import"; + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.map(modelDefinition); + Request model = new Request("PUT", "_ml/inference/" + importedModelId); + model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON)); + assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); + } + getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/regression*")); + + assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); + response = EntityUtils.toString(getModel.getEntity()); + assertThat(response, containsString("\"model_id\":\"regression_model_to_export\"")); + assertThat(response, containsString("\"model_id\":\"regression_model_to_import\"")); + assertThat(response, containsString("\"count\":2")); + } + private void putRegressionModel(String modelId) throws IOException { try(XContentBuilder builder = XContentFactory.jsonBuilder()) { TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder() diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 1efdddc468e..7cf6ae7b914 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -74,7 +74,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { @Override protected Set responseParams() { - return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); + return org.elasticsearch.common.collect.Set.of(TrainedModelConfig.DECOMPRESS_DEFINITION, TrainedModelConfig.FOR_EXPORT); } private static class RestToXContentListenerWithDefaultValues extends RestToXContentListener { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json index 3b5b42795bd..168d233c8e3 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -62,6 +62,12 @@ "required":false, "type":"list", "description":"A comma-separated list of tags that the model must have." + }, + "for_export": { + "required": false, + "type": "boolean", + "default": false, + "description": "Omits fields that are illegal to set on model PUT" } } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index e944ed37958..5a437fc4166 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -818,3 +818,24 @@ setup: } } } +--- +"Test for_export flag": + - do: + ml.get_trained_models: + model_id: "a-regression-model-1" + for_export: true + include_model_definition: true + decompress_definition: false + + - match: { trained_model_configs.0.description: "empty model for tests" } + - is_true: trained_model_configs.0.compressed_definition + - is_true: trained_model_configs.0.input + - is_true: trained_model_configs.0.inference_config + - is_true: trained_model_configs.0.tags + - is_false: trained_model_configs.0.model_id + - is_false: trained_model_configs.0.created_by + - is_false: trained_model_configs.0.version + - is_false: trained_model_configs.0.create_time + - is_false: trained_model_configs.0.estimated_heap_memory_usage + - is_false: trained_model_configs.0.estimated_operations + - is_false: trained_model_configs.0.license_level