[7.x] [ML] add _cat/ml/trained_models API (#51529) (#51936)

* [ML] add _cat/ml/trained_models API (#51529)

This adds _cat/ml/trained_models.
This commit is contained in:
Benjamin Trent 2020-02-05 08:26:44 -05:00 committed by GitHub
parent b70cbc97aa
commit 79f143907a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 504 additions and 0 deletions

View File

@ -103,6 +103,14 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
return modelId;
}
/**
 * Returns the ingest statistics held by this stats entry.
 *
 * @return the stored {@code IngestStats}; NOTE(review): the cat handler added in this
 *         change null-checks the result, so this may be {@code null} — confirm
 */
public IngestStats getIngestStats() {
return this.ingestStats;
}
/**
 * Returns the pipeline count held by this stats entry.
 *
 * @return the stored pipeline count
 */
public int getPipelineCount() {
return this.pipelineCount;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();

View File

@ -256,6 +256,7 @@ import org.elasticsearch.xpack.ml.rest.calendar.RestPutCalendarAction;
import org.elasticsearch.xpack.ml.rest.calendar.RestPutCalendarJobAction;
import org.elasticsearch.xpack.ml.rest.cat.RestCatDatafeedsAction;
import org.elasticsearch.xpack.ml.rest.cat.RestCatJobsAction;
import org.elasticsearch.xpack.ml.rest.cat.RestCatTrainedModelsAction;
import org.elasticsearch.xpack.ml.rest.datafeeds.RestDeleteDatafeedAction;
import org.elasticsearch.xpack.ml.rest.datafeeds.RestGetDatafeedStatsAction;
import org.elasticsearch.xpack.ml.rest.datafeeds.RestGetDatafeedsAction;
@ -786,6 +787,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
new RestPutTrainedModelAction(restController),
// CAT Handlers
new RestCatJobsAction(restController),
new RestCatTrainedModelsAction(restController),
new RestCatDatafeedsAction(restController)
);
}

View File

@ -177,6 +177,7 @@ public class AnalyticsResultProcessor {
.setCreatedBy(XPackUser.NAME)
.setVersion(Version.CURRENT)
.setCreateTime(createTime)
// NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
.setTags(Collections.singletonList(analytics.getId()))
.setDescription(analytics.getDescription())
.setMetadata(Collections.singletonMap("analytics_config",

View File

@ -0,0 +1,283 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.rest.cat;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.Table;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.action.RestResponseListener;
import org.elasticsearch.rest.action.cat.AbstractCatAction;
import org.elasticsearch.rest.action.cat.RestTable;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH;
/**
 * REST handler for {@code GET _cat/ml/trained_models} and {@code GET _cat/ml/trained_models/{model_id}}.
 *
 * <p>The handler first fetches the matching trained model configs, then concurrently fetches their stats and
 * the data frame analytics configs whose ids appear in the tags of models created by {@link XPackUser}, and
 * finally merges the three results into one cat {@link Table} with one row per trained model.
 */
public class RestCatTrainedModelsAction extends AbstractCatAction {

    /**
     * Registers this handler for both the "all models" path and the "specific model id" path.
     *
     * @param controller the controller to register the two GET routes with
     */
    public RestCatTrainedModelsAction(RestController controller) {
        controller.registerHandler(GET, "_cat/ml/trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this);
        controller.registerHandler(GET, "_cat/ml/trained_models", this);
    }

    @Override
    public String getName() {
        return "cat_ml_get_trained_models_action";
    }

    /**
     * Builds the config and stats requests from the REST parameters and returns a consumer that executes them.
     *
     * @param restRequest the incoming request; reads the model id path part, {@code from}/{@code size} paging
     *                    and the allow-no-match flag
     * @param client      the client used to execute the three underlying actions
     * @return a consumer that answers the channel with the rendered table, or with the first failure
     */
    @Override
    protected RestChannelConsumer doCatRequest(RestRequest restRequest, NodeClient client) {
        String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
        if (Strings.isNullOrEmpty(modelId)) {
            // No id in the path means "all models".
            modelId = MetaData.ALL;
        }
        GetTrainedModelsStatsAction.Request statsRequest = new GetTrainedModelsStatsAction.Request(modelId);
        GetTrainedModelsAction.Request modelsAction = new GetTrainedModelsAction.Request(modelId, false, null);
        if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
            // Parse the paging parameters once and apply the same window to both requests so that the
            // configs and the stats describe the same set of models.
            final int from = restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM);
            final int size = restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE);
            statsRequest.setPageParams(new PageParams(from, size));
            modelsAction.setPageParams(new PageParams(from, size));
        }
        // The stats request always tolerates "no match"; the configs request honours the user's flag and
        // defaults to the same value as the stats request.
        statsRequest.setAllowNoResources(true);
        modelsAction.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(),
            statsRequest.isAllowNoResources()));

        return channel -> {
            // Wrapped with notifyOnce because several callbacks below can report a failure.
            final ActionListener<Table> listener = ActionListener.notifyOnce(new RestResponseListener<Table>(channel) {
                @Override
                public RestResponse buildResponse(final Table table) throws Exception {
                    return RestTable.buildResponse(table, channel);
                }
            });

            client.execute(GetTrainedModelsAction.INSTANCE, modelsAction, ActionListener.wrap(
                trainedModels -> {
                    final List<TrainedModelConfig> trainedModelConfigs = trainedModels.getResources().results();

                    // Analytics Configs are created by the XPackUser, so only the tags of such models can
                    // name the analytics config that produced them.
                    Set<String> potentialAnalyticsIds = new HashSet<>();
                    trainedModelConfigs.stream()
                        .filter(c -> XPackUser.NAME.equals(c.getCreatedBy()))
                        .forEach(c -> potentialAnalyticsIds.addAll(c.getTags()));

                    // Find the related DataFrameAnalyticsConfigs
                    String requestIdPattern = Strings.collectionToDelimitedString(potentialAnalyticsIds, "*,") + "*";

                    // Collects the two follow-up responses (stats + analytics) before building the table.
                    final GroupedActionListener<ActionResponse> groupedListener = createGroupedListener(restRequest,
                        2,
                        trainedModelConfigs,
                        listener);

                    client.execute(GetTrainedModelsStatsAction.INSTANCE,
                        statsRequest,
                        ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure));

                    GetDataFrameAnalyticsAction.Request dataFrameAnalyticsRequest =
                        new GetDataFrameAnalyticsAction.Request(requestIdPattern);
                    dataFrameAnalyticsRequest.setAllowNoResources(true);
                    dataFrameAnalyticsRequest.setPageParams(new PageParams(0, potentialAnalyticsIds.size()));
                    client.execute(GetDataFrameAnalyticsAction.INSTANCE,
                        dataFrameAnalyticsRequest,
                        ActionListener.wrap(groupedListener::onResponse, groupedListener::onFailure));
                },
                listener::onFailure
            ));
        };
    }

    @Override
    protected void documentation(StringBuilder sb) {
        sb.append("/_cat/ml/trained_models\n");
        sb.append("/_cat/ml/trained_models/{model_id}\n");
    }

    /**
     * Creates an empty table carrying only the column headers.
     *
     * <p>The cells added per row in {@code buildTable} must stay in exactly the same order as the headers
     * declared here. NOTE(review): columns built with a {@code false} second argument appear to be hidden
     * unless explicitly requested via {@code h=} — confirm against {@code TableColumnAttributeBuilder}.
     *
     * @param request the current request (not read by this implementation)
     * @return a table with headers started and ended, and no rows
     */
    @Override
    protected Table getTableWithHeader(RestRequest request) {
        Table table = new Table();
        table.startHeaders();

        // Trained Model Info
        table.addCell("id", TableColumnAttributeBuilder.builder().setDescription("the trained model id").build());
        table.addCell("created_by", TableColumnAttributeBuilder.builder("who created the model", false)
            .setAliases("c", "createdBy")
            .setTextAlignment(TableColumnAttributeBuilder.TextAlign.RIGHT)
            .build());
        table.addCell("heap_size", TableColumnAttributeBuilder.builder()
            .setDescription("the estimated heap size to keep the model in memory")
            .setAliases("hs","modelHeapSize")
            .build());
        table.addCell("operations", TableColumnAttributeBuilder.builder()
            .setDescription("the estimated number of operations to use the model")
            .setAliases("o", "modelOperations")
            .build());
        table.addCell("license", TableColumnAttributeBuilder.builder("The license level of the model", false)
            .setAliases("l")
            .build());
        table.addCell("create_time", TableColumnAttributeBuilder.builder("The time the model was created")
            .setAliases("ct")
            .build());
        table.addCell("version", TableColumnAttributeBuilder.builder("The version of Elasticsearch when the model was created", false)
            .setAliases("v")
            .build());
        table.addCell("description", TableColumnAttributeBuilder.builder("The model description", false)
            .setAliases("d")
            .build());

        // Trained Model Stats
        table.addCell("ingest.pipelines", TableColumnAttributeBuilder.builder("The number of pipelines referencing the model")
            .setAliases("ip", "ingestPipelines")
            .build());
        table.addCell("ingest.count", TableColumnAttributeBuilder.builder("The total number of docs processed by the model", false)
            .setAliases("ic", "ingestCount")
            .build());
        table.addCell("ingest.time", TableColumnAttributeBuilder.builder(
            "The total time spent processing docs with this model",
            false)
            .setAliases("it", "ingestTime")
            .build());
        table.addCell("ingest.current", TableColumnAttributeBuilder.builder(
            "The total documents currently being handled by the model",
            false)
            .setAliases("icurr", "ingestCurrent")
            .build());
        table.addCell("ingest.failed", TableColumnAttributeBuilder.builder(
            "The total count of failed ingest attempts with this model",
            false)
            .setAliases("if", "ingestFailed")
            .build());

        // Originating data frame analytics config
        table.addCell("data_frame.id", TableColumnAttributeBuilder.builder(
            "The data frame analytics config id that created the model (if still available)")
            .setAliases("dfid", "dataFrameAnalytics")
            .build());
        table.addCell("data_frame.create_time", TableColumnAttributeBuilder.builder(
            "The time the data frame analytics config was created",
            false)
            .setAliases("dft", "dataFrameAnalyticsTime")
            .build());
        table.addCell("data_frame.source_index", TableColumnAttributeBuilder.builder(
            "The source index used to train in the data frame analysis",
            false)
            .setAliases("dfsi", "dataFrameAnalyticsSrcIndex")
            .build());
        table.addCell("data_frame.analysis", TableColumnAttributeBuilder.builder(
            "The analysis used by the data frame to build the model",
            false)
            .setAliases("dfa", "dataFrameAnalyticsAnalysis")
            .build());

        table.endHeaders();
        return table;
    }

    /**
     * Creates the listener that waits for the stats and analytics responses and then builds the table.
     *
     * @param request  the current request, forwarded to the table builder
     * @param size     the number of responses to wait for
     * @param configs  the trained model configs already fetched; one table row is produced per config
     * @param listener receives the finished table, or the failure
     * @return a grouped listener to be notified by each follow-up action
     */
    private GroupedActionListener<ActionResponse> createGroupedListener(final RestRequest request,
                                                                        final int size,
                                                                        final List<TrainedModelConfig> configs,
                                                                        final ActionListener<Table> listener) {
        return new GroupedActionListener<>(new ActionListener<Collection<ActionResponse>>() {
            @Override
            public void onResponse(final Collection<ActionResponse> responses) {
                GetTrainedModelsStatsAction.Response statsResponse =
                    extractResponse(responses, GetTrainedModelsStatsAction.Response.class);
                if (statsResponse == null) {
                    // Stats are mandatory for the table; report a clear failure rather than dereferencing null.
                    listener.onFailure(new IllegalStateException("missing trained model stats response"));
                    return;
                }
                // The analytics lookup is best-effort: without it the data_frame.* columns are simply empty.
                GetDataFrameAnalyticsAction.Response analytics =
                    extractResponse(responses, GetDataFrameAnalyticsAction.Response.class);
                listener.onResponse(buildTable(request,
                    statsResponse.getResources().results(),
                    configs,
                    analytics == null ? Collections.emptyList() : analytics.getResources().results()));
            }

            @Override
            public void onFailure(final Exception e) {
                listener.onFailure(e);
            }
        }, size);
    }

    /**
     * Merges configs, stats and analytics configs into the cat table, one row per trained model config.
     *
     * @param request          the current request, used to create the header
     * @param stats            stats entries, looked up by model id
     * @param configs          the trained model configs; drives the rows and their order
     * @param analyticsConfigs analytics configs, looked up by id against each model's tags
     * @return the populated table
     */
    private Table buildTable(RestRequest request,
                             List<GetTrainedModelsStatsAction.Response.TrainedModelStats> stats,
                             List<TrainedModelConfig> configs,
                             List<DataFrameAnalyticsConfig> analyticsConfigs) {
        Table table = getTableWithHeader(request);
        assert configs.size() == stats.size();

        Map<String, DataFrameAnalyticsConfig> analyticsMap = analyticsConfigs.stream()
            .collect(Collectors.toMap(DataFrameAnalyticsConfig::getId, Function.identity()));
        Map<String, GetTrainedModelsStatsAction.Response.TrainedModelStats> statsMap = stats.stream()
            .collect(Collectors.toMap(GetTrainedModelsStatsAction.Response.TrainedModelStats::getModelId, Function.identity()));

        configs.forEach(config -> {
            table.startRow();
            // Trained Model Info
            table.addCell(config.getModelId());
            table.addCell(config.getCreatedBy());
            table.addCell(new ByteSizeValue(config.getEstimatedHeapMemory()));
            table.addCell(config.getEstimatedOperations());
            table.addCell(config.getLicenseLevel());
            table.addCell(config.getCreateTime());
            table.addCell(config.getVersion().toString());
            table.addCell(config.getDescription());

            // Trained Model Stats. A config may have no stats entry in the map, so every stats-derived
            // cell (including the pipeline count) falls back to zero instead of dereferencing null.
            GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = statsMap.get(config.getModelId());
            table.addCell(modelStats == null ? 0 : modelStats.getPipelineCount());
            boolean hasIngestStats = modelStats != null && modelStats.getIngestStats() != null;
            table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCount() : 0);
            table.addCell(hasIngestStats ?
                TimeValue.timeValueMillis(modelStats.getIngestStats().getTotalStats().getIngestTimeInMillis()) :
                TimeValue.timeValueMillis(0));
            table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestCurrent() : 0);
            table.addCell(hasIngestStats ? modelStats.getIngestStats().getTotalStats().getIngestFailedCount() : 0);

            // The first tag that names a fetched analytics config identifies the model's origin.
            DataFrameAnalyticsConfig dataFrameAnalyticsConfig = config.getTags()
                .stream()
                .filter(analyticsMap::containsKey)
                .map(analyticsMap::get)
                .findFirst()
                .orElse(null);
            table.addCell(dataFrameAnalyticsConfig == null ? "__none__" : dataFrameAnalyticsConfig.getId());
            table.addCell(dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getCreateTime());
            table.addCell(dataFrameAnalyticsConfig == null ?
                null :
                Strings.arrayToCommaDelimitedString(dataFrameAnalyticsConfig.getSource().getIndex()));
            DataFrameAnalysis analysis = dataFrameAnalyticsConfig == null ? null : dataFrameAnalyticsConfig.getAnalysis();
            table.addCell(analysis == null ? null : analysis.getWriteableName());

            table.endRow();
        });
        return table;
    }

    /**
     * Picks the first response of the requested type out of a mixed collection.
     *
     * @param responses the collected responses
     * @param c         the response class to look for
     * @param <A>       the response type
     * @return the first instance of {@code c}, or {@code null} when none is present
     */
    private static <A extends ActionResponse> A extractResponse(final Collection<? extends ActionResponse> responses, Class<A> c) {
        return responses.stream().filter(c::isInstance).map(c::cast).findFirst().orElse(null);
    }
}

View File

@ -0,0 +1,100 @@
{
"cat.ml.trained_models":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/get-inference-stats.html"
},
"stability":"stable",
"url":{
"paths":[
{
"path":"/_cat/ml/trained_models",
"methods":[
"GET"
]
},
{
"path":"/_cat/ml/trained_models/{model_id}",
"methods":[
"GET"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained models stats to fetch"
}
}
}
]
},
"params":{
"allow_no_match":{
"type":"boolean",
"required":false,
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
"default":true
},
"from":{
"type":"int",
"description":"skips a number of trained models",
"default":0
},
"size":{
"type":"int",
"description":"specifies a max number of trained models to get",
"default":100
},
"bytes":{
"type":"enum",
"description":"The unit in which to display byte values",
"options":[
"b",
"k",
"kb",
"m",
"mb",
"g",
"gb",
"t",
"tb",
"p",
"pb"
]
},
"format":{
"type":"string",
"description":"a short version of the Accept header, e.g. json, yaml"
},
"h":{
"type":"list",
"description":"Comma-separated list of column names to display"
},
"help":{
"type":"boolean",
"description":"Return help information",
"default":false
},
"s":{
"type":"list",
"description":"Comma-separated list of column names or column aliases to sort by"
},
"time":{
"type":"enum",
"description":"The unit in which to display time values",
"options":[
"d (Days)",
"h (Hours)",
"m (Minutes)",
"s (Seconds)",
"ms (Milliseconds)",
"micros (Microseconds)",
"nanos (Nanoseconds)"
]
},
"v":{
"type":"boolean",
"description":"Verbose mode. Display column headers",
"default":false
}
}
}
}

View File

@ -0,0 +1,110 @@
# Shared fixtures: one source index, two user-created regression models, and one
# data frame analytics config. Requests that create ML resources send an explicit
# Authorization header, hence the `headers` feature requirement.
setup:
- skip:
features: headers
# Source index referenced by the data frame analytics config created at the end of setup.
- do:
indices.create:
index: index-source
# First model: a single-leaf regression tree, stored with user-supplied tags.
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-regression-model-0
body: >
{
"description": "empty model for tests",
"tags": ["regression", "tag1"],
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "leaf_value": 1}
],
"target_type": "regression"
}
}
}
}
# Second model: same definition as the first, but stored without any tags.
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: a-regression-model-1
body: >
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "leaf_value": 1}
],
"target_type": "regression"
}
}
}
}
# Analytics config with id "prepackaged". The test below expects lang_ident_model_1 to
# report this id in its dfid column. NOTE(review): presumably that built-in model carries
# a "prepackaged" tag, which the cat handler matches against analytics ids - confirm.
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_data_frame_analytics:
id: "prepackaged"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {"regression":{
"dependent_variable": "to_predict"
}}
}
---
# Exercises the cat API four ways: default columns without headers, default columns
# with headers, explicit columns (via aliases) for all models, and explicit columns
# for a single model id.
"Test cat trained models":
# 1) Single model, no `v`: only the data row(s) with the default columns are returned.
- do:
cat.ml.trained_models:
model_id: a-regression-model-0
- match:
$body: |
/ #id heap_size operations create_time ingest.pipelines data_frame.id
^ (a\-regression\-model\-0 \s+ \w+ \s+ \d+ \s+ .*? \s+ \d+ .*? \n)+ $/
# 2) Same request with `v: true`: a header line naming the default columns precedes the rows.
- do:
cat.ml.trained_models:
v: true
model_id: a-regression-model-0
- match:
$body: |
/^ id \s+ heap_size \s+ operations \s+ create_time \s+ ingest\.pipelines \s+ data_frame\.id \n
(a\-regression\-model\-0 \s+ \w+ \s+ \d+ \s+ .*? \s+ \d+ \s+ .*? \n)+ $/
# 3) All models with columns selected through `h` (including the dfid and ip aliases).
#    The two user-created models show __none__ as their analytics id; lang_ident_model_1
#    is expected to show the "prepackaged" analytics config created in setup.
- do:
cat.ml.trained_models:
h: id,license,dfid,ip
v: true
- match:
$body: |
/^ id \s+ license \s+ dfid \s+ ip \n
(a\-regression\-model\-0 \s+ \w+ \s+ __none__ \s+ \d+ \n)+
(a\-regression\-model\-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+
(lang_ident_model_1 \s+ \w+ \s+ prepackaged \s+ \d+ \n)+ $/
# 4) Column selection combined with a specific model id: only that model's row is returned.
- do:
cat.ml.trained_models:
model_id: a-regression-model-1
h: id,license,dfid,ip
v: true
- match:
$body: |
/^ id \s+ license \s+ dfid \s+ ip \n
(a\-regression\-model\-1 \s+ \w+ \s+ __none__ \s+ \d+ \n)+ $/