[ML][Inference] don't return inflated definition when storing trained models (#52573) (#52580)

When `PUT` is called to store a trained model, it is useful to return the newly create model config. But, it is NOT useful to return the inflated definition.

These definitions can be large and returning the inflated definition causes undo work on the server and client side.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
Benjamin Trent 2020-02-20 19:47:29 -05:00 committed by GitHub
parent 4bce9984e6
commit 2a5c181dda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 99 additions and 9 deletions

View File

@ -46,6 +46,8 @@ include::../execution.asciidoc[]
==== Response ==== Response
The returned +{response}+ contains the newly created trained model. The returned +{response}+ contains the newly created trained model.
The +{response}+ will omit the model definition as a precaution against
streaming large model definitions back to the client.
["source","java",subs="attributes,callouts,macros"] ["source","java",subs="attributes,callouts,macros"]
-------------------------------------------------- --------------------------------------------------

View File

@ -279,7 +279,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
// We don't store the definition in the same document as the configuration // We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) { if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
builder.field(DEFINITION.getPreferredName(), definition); builder.field(DEFINITION.getPreferredName(), definition);
} else { } else {
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString()); builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
@ -370,6 +370,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.tags = config.getTags(); this.tags = config.getTags();
this.metadata = config.getMetadata(); this.metadata = config.getMetadata();
this.input = config.getInput(); this.input = config.getInput();
this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory;
this.licenseLevel = config.licenseLevel.description();
} }
public Builder setModelId(String modelId) { public Builder setModelId(String modelId) {

View File

@ -142,21 +142,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
"platinum"); "platinum");
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("\"definition\"")); assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
reference = XContentHelper.toXContent(config, reference = XContentHelper.toXContent(config,
XContentType.JSON, XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
false); false);
assertThat(reference.utf8ToString(), not(containsString("definition"))); assertThat(reference.utf8ToString(), not(containsString("definition")));
assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
reference = XContentHelper.toXContent(config, reference = XContentHelper.toXContent(config,
XContentType.JSON, XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")), new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")),
false); false);
assertThat(reference.utf8ToString(), not(containsString("\"definition\""))); assertThat(reference.utf8ToString(), containsString("\"definition\""));
assertThat(reference.utf8ToString(), containsString("compressed_definition")); assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString()));
} }
public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException { public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
@ -179,7 +179,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2(); Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();
objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString()); objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition());
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap); try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap);
XContentParser parser = XContentType.JSON XContentParser parser = XContentType.JSON

View File

@ -93,6 +93,7 @@ public class TrainedModelIT extends ESRestTestCase {
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
assertThat(response, containsString("\"estimated_heap_memory_usage\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\""));
assertThat(response, containsString("\"definition\"")); assertThat(response, containsString("\"definition\""));
assertThat(response, not(containsString("\"compressed_definition\"")));
assertThat(response, containsString("\"count\":1")); assertThat(response, containsString("\"count\":1"));
getModel = client().performRequest(new Request("GET", getModel = client().performRequest(new Request("GET",

View File

@ -105,7 +105,10 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap( ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)), bool -> {
TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build();
listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
},
listener::onFailure listener::onFailure
)), )),
listener::onFailure listener::onFailure

View File

@ -8,8 +8,14 @@ package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@ -18,7 +24,9 @@ import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
@ -35,6 +43,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
new Route(GET, MachineLearning.BASE_PATH + "inference"))); new Route(GET, MachineLearning.BASE_PATH + "inference")));
} }
private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
@Override @Override
public String getName() { public String getName() {
return "ml_get_trained_models_action"; return "ml_get_trained_models_action";
@ -57,7 +67,9 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
} }
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources())); request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel)); return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
request,
new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
} }
@Override @Override
@ -65,4 +77,23 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION); return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
} }
private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
private final Map<String, String> defaultToXContentParamValues;
private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
super(channel);
this.defaultToXContentParamValues = defaultToXContentParamValues;
}
@Override
public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
assert response.isFragment() == false; //would be nice if we could make default methods final
Map<String, String> params = new HashMap<>(channel.request().params());
defaultToXContentParamValues.forEach((k, v) ->
params.computeIfAbsent(k, defaultToXContentParamValues::get)
);
response.toXContent(builder, new ToXContent.MapParams(params));
return new BytesRestResponse(getStatus(response), builder);
}
}
} }

View File

@ -460,3 +460,53 @@ setup:
} }
} }
} }
---
"Test put model":
- do:
ml.put_trained_model:
model_id: my-regression-model
body: >
{
"description": "model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"ensemble": {
"target_type": "regression",
"trained_models": [
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
},
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
}
]
}
}
}
}
- match: { model_id: my-regression-model }
- match: { estimated_operations: 6 }
- is_false: definition
- is_false: compressed_definition
- is_true: license_level
- is_true: create_time
- is_true: version
- is_true: estimated_heap_memory_usage_bytes