When `PUT` is called to store a trained model, it is useful to return the newly created model config. However, it is NOT useful to return the inflated definition. These definitions can be large, and returning the inflated definition causes undue work on both the server and the client side. Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
This commit is contained in:
parent
4bce9984e6
commit
2a5c181dda
|
@ -46,6 +46,8 @@ include::../execution.asciidoc[]
|
||||||
==== Response
|
==== Response
|
||||||
|
|
||||||
The returned +{response}+ contains the newly created trained model.
|
The returned +{response}+ contains the newly created trained model.
|
||||||
|
The +{response}+ will omit the model definition as a precaution against
|
||||||
|
streaming large model definitions back to the client.
|
||||||
|
|
||||||
["source","java",subs="attributes,callouts,macros"]
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
|
|
|
@ -279,7 +279,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
|
||||||
// We don't store the definition in the same document as the configuration
|
// We don't store the definition in the same document as the configuration
|
||||||
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
|
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
|
||||||
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) {
|
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
|
||||||
builder.field(DEFINITION.getPreferredName(), definition);
|
builder.field(DEFINITION.getPreferredName(), definition);
|
||||||
} else {
|
} else {
|
||||||
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
|
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
|
||||||
|
@ -370,6 +370,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
this.tags = config.getTags();
|
this.tags = config.getTags();
|
||||||
this.metadata = config.getMetadata();
|
this.metadata = config.getMetadata();
|
||||||
this.input = config.getInput();
|
this.input = config.getInput();
|
||||||
|
this.estimatedOperations = config.estimatedOperations;
|
||||||
|
this.estimatedHeapMemory = config.estimatedHeapMemory;
|
||||||
|
this.licenseLevel = config.licenseLevel.description();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder setModelId(String modelId) {
|
public Builder setModelId(String modelId) {
|
||||||
|
|
|
@ -142,21 +142,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
"platinum");
|
"platinum");
|
||||||
|
|
||||||
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
|
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
|
||||||
assertThat(reference.utf8ToString(), containsString("\"definition\""));
|
assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
|
||||||
|
|
||||||
reference = XContentHelper.toXContent(config,
|
reference = XContentHelper.toXContent(config,
|
||||||
XContentType.JSON,
|
XContentType.JSON,
|
||||||
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
|
new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
|
||||||
false);
|
false);
|
||||||
assertThat(reference.utf8ToString(), not(containsString("definition")));
|
assertThat(reference.utf8ToString(), not(containsString("definition")));
|
||||||
|
assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
|
||||||
|
|
||||||
reference = XContentHelper.toXContent(config,
|
reference = XContentHelper.toXContent(config,
|
||||||
XContentType.JSON,
|
XContentType.JSON,
|
||||||
new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")),
|
new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")),
|
||||||
false);
|
false);
|
||||||
assertThat(reference.utf8ToString(), not(containsString("\"definition\"")));
|
assertThat(reference.utf8ToString(), containsString("\"definition\""));
|
||||||
assertThat(reference.utf8ToString(), containsString("compressed_definition"));
|
assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
|
||||||
assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
|
public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
|
||||||
|
@ -179,7 +179,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
|
||||||
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
|
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
|
||||||
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();
|
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();
|
||||||
|
|
||||||
objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString());
|
objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition());
|
||||||
|
|
||||||
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap);
|
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap);
|
||||||
XContentParser parser = XContentType.JSON
|
XContentParser parser = XContentType.JSON
|
||||||
|
|
|
@ -93,6 +93,7 @@ public class TrainedModelIT extends ESRestTestCase {
|
||||||
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
|
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
|
||||||
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
|
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
|
||||||
assertThat(response, containsString("\"definition\""));
|
assertThat(response, containsString("\"definition\""));
|
||||||
|
assertThat(response, not(containsString("\"compressed_definition\"")));
|
||||||
assertThat(response, containsString("\"count\":1"));
|
assertThat(response, containsString("\"count\":1"));
|
||||||
|
|
||||||
getModel = client().performRequest(new Request("GET",
|
getModel = client().performRequest(new Request("GET",
|
||||||
|
|
|
@ -105,7 +105,10 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
|
||||||
|
|
||||||
ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
|
ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
|
||||||
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
|
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
|
||||||
storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
|
bool -> {
|
||||||
|
TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build();
|
||||||
|
listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
|
||||||
|
},
|
||||||
listener::onFailure
|
listener::onFailure
|
||||||
)),
|
)),
|
||||||
listener::onFailure
|
listener::onFailure
|
||||||
|
|
|
@ -8,8 +8,14 @@ package org.elasticsearch.xpack.ml.rest.inference;
|
||||||
import org.elasticsearch.client.node.NodeClient;
|
import org.elasticsearch.client.node.NodeClient;
|
||||||
import org.elasticsearch.cluster.metadata.MetaData;
|
import org.elasticsearch.cluster.metadata.MetaData;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.rest.BaseRestHandler;
|
import org.elasticsearch.rest.BaseRestHandler;
|
||||||
|
import org.elasticsearch.rest.BytesRestResponse;
|
||||||
|
import org.elasticsearch.rest.RestChannel;
|
||||||
import org.elasticsearch.rest.RestRequest;
|
import org.elasticsearch.rest.RestRequest;
|
||||||
|
import org.elasticsearch.rest.RestResponse;
|
||||||
import org.elasticsearch.rest.action.RestToXContentListener;
|
import org.elasticsearch.rest.action.RestToXContentListener;
|
||||||
import org.elasticsearch.xpack.core.action.util.PageParams;
|
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||||
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
|
||||||
|
@ -18,7 +24,9 @@ import org.elasticsearch.xpack.ml.MachineLearning;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
import static java.util.Arrays.asList;
|
||||||
|
@ -35,6 +43,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
||||||
new Route(GET, MachineLearning.BASE_PATH + "inference")));
|
new Route(GET, MachineLearning.BASE_PATH + "inference")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
|
||||||
|
Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
|
||||||
@Override
|
@Override
|
||||||
public String getName() {
|
public String getName() {
|
||||||
return "ml_get_trained_models_action";
|
return "ml_get_trained_models_action";
|
||||||
|
@ -57,7 +67,9 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
||||||
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
|
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
|
||||||
}
|
}
|
||||||
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
|
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
|
||||||
return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
|
return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
|
||||||
|
request,
|
||||||
|
new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -65,4 +77,23 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
|
||||||
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
|
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * A {@link RestToXContentListener} that renders the response using the REST request's
 * parameters, supplemented with caller-supplied default values for any parameter the
 * request did not set.
 *
 * @param <T> the response type; must be a complete (non-fragment) {@link ToXContentObject}
 */
private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
|
||||||
|
// Parameter name -> value to use when the request itself carries no value for that name.
private final Map<String, String> defaultToXContentParamValues;
|
||||||
|
|
||||||
|
/**
 * @param channel the REST channel, passed through to the parent listener
 * @param defaultToXContentParamValues defaults applied to parameters absent from the request
 */
private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
|
||||||
|
super(channel);
|
||||||
|
this.defaultToXContentParamValues = defaultToXContentParamValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Serializes {@code response} into {@code builder} using the request parameters merged
 * with the configured defaults, and wraps the result in a {@link BytesRestResponse}.
 *
 * @param response the action response to render
 * @param builder the builder the response is written into
 * @return a REST response carrying the status from {@code getStatus(response)} and the built content
 * @throws Exception declared by the overridden method
 */
@Override
|
||||||
|
public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
|
||||||
|
assert response.isFragment() == false; //would be nice if we could make default methods final
|
||||||
|
// Work on a copy so the defaults are added without modifying the request's own parameter map.
Map<String, String> params = new HashMap<>(channel.request().params());
|
||||||
|
// Fill in a default only where the request supplied no value; explicit request values win.
// NOTE(review): the lambda's `v` is unused and the value is looked up again through the
// method reference; `params.putIfAbsent(k, v)` looks like a simpler equivalent - confirm.
defaultToXContentParamValues.forEach((k, v) ->
|
||||||
|
params.computeIfAbsent(k, defaultToXContentParamValues::get)
|
||||||
|
);
|
||||||
|
response.toXContent(builder, new ToXContent.MapParams(params));
|
||||||
|
return new BytesRestResponse(getStatus(response), builder);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -460,3 +460,53 @@ setup:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
---
|
||||||
|
"Test put model":
|
||||||
|
- do:
|
||||||
|
ml.put_trained_model:
|
||||||
|
model_id: my-regression-model
|
||||||
|
body: >
|
||||||
|
{
|
||||||
|
"description": "model for tests",
|
||||||
|
"input": {"field_names": ["field1", "field2"]},
|
||||||
|
"definition": {
|
||||||
|
"preprocessors": [],
|
||||||
|
"trained_model": {
|
||||||
|
"ensemble": {
|
||||||
|
"target_type": "regression",
|
||||||
|
"trained_models": [
|
||||||
|
{
|
||||||
|
"tree": {
|
||||||
|
"feature_names": ["field1", "field2"],
|
||||||
|
"tree_structure": [
|
||||||
|
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
|
||||||
|
{"node_index": 1, "leaf_value": 0},
|
||||||
|
{"node_index": 2, "leaf_value": 1}
|
||||||
|
],
|
||||||
|
"target_type": "regression"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tree": {
|
||||||
|
"feature_names": ["field1", "field2"],
|
||||||
|
"tree_structure": [
|
||||||
|
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
|
||||||
|
{"node_index": 1, "leaf_value": 0},
|
||||||
|
{"node_index": 2, "leaf_value": 1}
|
||||||
|
],
|
||||||
|
"target_type": "regression"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
- match: { model_id: my-regression-model }
|
||||||
|
- match: { estimated_operations: 6 }
|
||||||
|
- is_false: definition
|
||||||
|
- is_false: compressed_definition
|
||||||
|
- is_true: license_level
|
||||||
|
- is_true: create_time
|
||||||
|
- is_true: version
|
||||||
|
- is_true: estimated_heap_memory_usage_bytes
|
||||||
|
|
Loading…
Reference in New Issue