From 76660a5a4f49eb0e96e3ec321d070c3c15f046ea Mon Sep 17 00:00:00 2001 From: Benjamin Trent <ben.w.trent@gmail.com> Date: Fri, 24 Jan 2020 08:26:58 -0500 Subject: [PATCH] [7.x] [ML][Inference] add tags url param to GET (#51330) (#51404) * [ML][Inference] add tags url param to GET (#51330) Adds a new URL parameter, `tags` to the GET _ml/inference/<model_id> endpoint. This parameter allows the list of models to be further reduced to those who contain all the provided tags. --- .../client/MLRequestConverters.java | 3 ++ .../client/ml/GetTrainedModelsRequest.java | 26 ++++++++++++++ .../client/MLRequestConvertersTests.java | 2 ++ .../MlClientDocumentationIT.java | 4 ++- .../high-level/ml/get-trained-models.asciidoc | 3 ++ .../apis/get-inference-trained-model.asciidoc | 5 ++- docs/reference/ml/ml-shared.asciidoc | 6 ++++ .../ml/action/GetTrainedModelsAction.java | 22 ++++++++++-- .../action/GetTrainedModelsRequestTests.java | 5 ++- .../TransportGetTrainedModelsAction.java | 7 +++- .../TransportGetTrainedModelsStatsAction.java | 7 +++- .../persistence/TrainedModelProvider.java | 26 ++++++++++++-- .../inference/RestGetTrainedModelsAction.java | 5 ++- .../TrainedModelProviderTests.java | 26 ++++++++++++++ .../api/ml.get_trained_models.json | 7 ++++ .../rest-api-spec/test/ml/inference_crud.yml | 35 +++++++++++++++++++ 16 files changed, 177 insertions(+), 12 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 2e077f547e3..bf220d63b3c 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -755,6 +755,9 @@ final class MLRequestConverters { params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION, Boolean.toString(getTrainedModelsRequest.getIncludeDefinition())); } + if (getTrainedModelsRequest.getTags() != null) { + params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags())); + } Request request = new Request(HttpGet.METHOD_NAME, endpoint); request.addParameters(params.asMap()); return request; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java index 9234770a97a..d9aeb52d973 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java @@ -22,6 +22,7 @@ package org.elasticsearch.client.ml; import org.elasticsearch.client.Validatable; import org.elasticsearch.client.ValidationException; import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.client.ml.inference.TrainedModelConfig; import org.elasticsearch.common.Nullable; import java.util.Arrays; @@ -34,12 +35,14 @@ public class GetTrainedModelsRequest implements Validatable { public static final String ALLOW_NO_MATCH = "allow_no_match"; public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition"; public static final String DECOMPRESS_DEFINITION = "decompress_definition"; + public static final String TAGS = "tags"; private final List<String> ids; private Boolean allowNoMatch; private Boolean includeDefinition; private Boolean decompressDefinition; private PageParams pageParams; + private List<String> tags; /** * Helper method to create a request that will get ALL TrainedModelConfigs @@ -111,6 +114,29 @@ public class GetTrainedModelsRequest implements Validatable { return this; } + public List<String> getTags() { + return tags; + } + + /** + * The tags that the trained model must match. These correspond to {@link TrainedModelConfig#getTags()}. + * + * The models returned will match ALL tags supplied. + * If none are provided, only the provided ids are used to find models + * @param tags The tags to match when finding models + */ + public GetTrainedModelsRequest setTags(List<String> tags) { + this.tags = tags; + return this; + } + + /** + * See {@link GetTrainedModelsRequest#setTags(List)} + */ + public GetTrainedModelsRequest setTags(String... tags) { + return setTags(Arrays.asList(tags)); + } + @Override public Optional<ValidationException> validate() { if (ids == null || ids.isEmpty()) { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index f5733ef3a0d..395c62ff837 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -834,6 +834,7 @@ public class MLRequestConvertersTests extends ESTestCase { .setAllowNoMatch(false) .setDecompressDefinition(true) .setIncludeDefinition(false) + .setTags("tag1", "tag2") .setPageParams(new PageParams(100, 300)); Request request = MLRequestConverters.getTrainedModels(getRequest); @@ -845,6 +846,7 @@ public class MLRequestConvertersTests extends ESTestCase { hasEntry("size", "300"), hasEntry("allow_no_match", "false"), hasEntry("decompress_definition", "true"), + hasEntry("tags", "tag1,tag2"), hasEntry("include_model_definition", "false") )); assertNull(request.getEntity()); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 142f1f1f660..5b6fbfe043a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -3587,8 +3587,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase { .setPageParams(new PageParams(0, 1)) // <2> .setIncludeDefinition(false) // <3> .setDecompressDefinition(false) // <4> - .setAllowNoMatch(true); // <5> + .setAllowNoMatch(true) // <5> + .setTags("regression"); // <6> // end::get-trained-models-request + request.setTags((List<String>)null); // tag::get-trained-models-execute GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT); diff --git a/docs/java-rest/high-level/ml/get-trained-models.asciidoc b/docs/java-rest/high-level/ml/get-trained-models.asciidoc index 4ad9f009126..42cd060d881 100644 --- a/docs/java-rest/high-level/ml/get-trained-models.asciidoc +++ b/docs/java-rest/high-level/ml/get-trained-models.asciidoc @@ -29,6 +29,9 @@ include-tagged::{doc-tests-file}[{api}-request] <5> Allow empty response if no Trained Models match the provided ID patterns. If false, an error will be thrown if no Trained Models match the ID patterns. +<6> An optional list of tags used to narrow the model search. A Trained Model + can have many tags or none. The trained models in the response will + contain all the provided tags. include::../execution.asciidoc[] diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc index e972cf040e2..af4e1a2c9ce 100644 --- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc @@ -74,6 +74,9 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=include-model-definition] (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=size] +`tags`:: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=tags] [[ml-get-inference-response-codes]] ==== {api-response-codes-title} @@ -96,4 +99,4 @@ The following example gets configuration information for all the trained models: -------------------------------------------------- GET _ml/inference/ -------------------------------------------------- -// TEST[skip:TBD] \ No newline at end of file +// TEST[skip:TBD] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index c0219c225ca..893bf0c9e48 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -642,6 +642,12 @@ to `false`. When `true`, only a single model must match the ID patterns provided, otherwise a bad request is returned. end::include-model-definition[] +tag::tags[] +A comma delimited string of tags. A {infer} model can have many tags, or none. +When supplied, only {infer} models that contain all the supplied tags are +returned. +end::tags[] + tag::indices[] An array of index names. Wildcards are supported. For example: `["it_ops_metrics", "server*"]`. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java index b86cfced552..84330f7924a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -33,18 +34,26 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition"); public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + public static final ParseField TAGS = new ParseField("tags"); private final boolean includeModelDefinition; + private final List<String> tags; - public Request(String id, boolean includeModelDefinition) { + public Request(String id, boolean includeModelDefinition, List<String> tags) { setResourceId(id); setAllowNoResources(true); this.includeModelDefinition = includeModelDefinition; + this.tags = tags == null ? Collections.emptyList() : tags; } public Request(StreamInput in) throws IOException { super(in); this.includeModelDefinition = in.readBoolean(); + if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + this.tags = in.readStringList(); + } else { + this.tags = Collections.emptyList(); + } } @Override @@ -56,15 +65,22 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re return includeModelDefinition; } + public List<String> getTags() { + return tags; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeBoolean(includeModelDefinition); + if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeStringCollection(tags); + } } @Override public int hashCode() { - return Objects.hash(super.hashCode(), includeModelDefinition); + return Objects.hash(super.hashCode(), includeModelDefinition, tags); } @Override @@ -76,7 +92,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re return false; } Request other = (Request) obj; - return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition; + return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java index 85345467df1..7955117e117 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java @@ -14,7 +14,10 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas @Override protected Request createTestInstance() { - Request request = new Request(randomAlphaOfLength(20), randomBoolean()); + Request request = new Request(randomAlphaOfLength(20), + randomBoolean(), + randomBoolean() ? null : + randomList(10, () -> randomAlphaOfLength(10))); request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); return request; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index 4b467912e1b..1ffc13b8b11 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.Collections; +import java.util.HashSet; import java.util.Set; @@ -70,7 +71,11 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ listener::onFailure ); - provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener); + provider.expandIds(request.getResourceId(), + request.isAllowNoResources(), + request.getPageParams(), + new HashSet<>(request.getTags()), + idExpansionListener); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index a15579b62de..33678ab9089 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; @@ -94,7 +95,11 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction listener::onFailure ); - trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener); + trainedModelProvider.expandIds(request.getResourceId(), + request.isAllowNoResources(), + request.getPageParams(), + Collections.emptySet(), + idsListener); } static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index d63dbf1bc4b..f3cf9f56aab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -70,6 +70,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; @@ -382,6 +383,7 @@ public class TrainedModelProvider { public void expandIds(String idExpression, boolean allowNoResources, @Nullable PageParams pageParams, + Set<String> tags, ActionListener<Tuple<Long, Set<String>>> idsListener) { String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() @@ -389,7 +391,7 @@ public class TrainedModelProvider { // If there are no resources, there might be no mapping for the id field. // This makes sure we don't get an error if that happens. .unmappedType("long")) - .query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())); + .query(buildExpandIdsQuery(tokens, tags)); if (pageParams != null) { sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize()); } @@ -405,13 +407,23 @@ public class TrainedModelProvider { indicesOptions.expandWildcardsClosed(), indicesOptions)) .source(sourceBuilder); + Set<String> foundResourceIds = new LinkedHashSet<>(); + if (tags.isEmpty()) { + foundResourceIds.addAll(matchedResourceIds(tokens)); + } else { + for(String resourceId : matchedResourceIds(tokens)) { + // Does the model as a resource have all the tags? + if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { + foundResourceIds.add(resourceId); + } + } + } executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest, ActionListener.<SearchResponse>wrap( response -> { - Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens)); long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size(); for (SearchHit hit : response.getHits().getHits()) { Map<String, Object> docSource = hit.getSourceAsMap(); @@ -434,7 +446,15 @@ public class TrainedModelProvider { idsListener::onFailure ), client::search); + } + static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) { + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery() + .filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())); + for(String tag : tags) { + boolQueryBuilder.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), tag)); + } + return QueryBuilders.constantScoreQuery(boolQueryBuilder); } TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) { @@ -468,7 +488,7 @@ public class TrainedModelProvider { } } - private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) { + private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) { BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java index 1aa0fd42350..4d818f974f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java @@ -18,7 +18,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Set; import static org.elasticsearch.rest.RestRequest.Method.GET; @@ -47,7 +49,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler { GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(), false ); - GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition); + List<String> tags = Arrays.asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY)); + GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags); if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 705a8f60dd2..1f90313899d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -9,16 +9,24 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.ConstantScoreQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import java.util.Arrays; + import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.oneOf; import static org.mockito.Mockito.mock; public class TrainedModelProviderTests extends ESTestCase { @@ -60,6 +68,24 @@ public class TrainedModelProviderTests extends ESTestCase { } } + public void testExpandIdsQuery() { + QueryBuilder queryBuilder = TrainedModelProvider.buildExpandIdsQuery(new String[]{"model*", "trained_mode"}, + Arrays.asList("tag1", "tag2")); + assertThat(queryBuilder, is(instanceOf(ConstantScoreQueryBuilder.class))); + + QueryBuilder innerQuery = ((ConstantScoreQueryBuilder)queryBuilder).innerQuery(); + assertThat(innerQuery, is(instanceOf(BoolQueryBuilder.class))); + + ((BoolQueryBuilder)innerQuery).filter().forEach(qb -> { + if (qb instanceof TermQueryBuilder) { + assertThat(((TermQueryBuilder)qb).fieldName(), equalTo(TrainedModelConfig.TAGS.getPreferredName())); + assertThat(((TermQueryBuilder)qb).value(), is(oneOf("tag1", "tag2"))); + return; + } + assertThat(qb, is(instanceOf(BoolQueryBuilder.class))); + }); + } + public void testGetModelThatExistsAsResourceButIsMissing() { TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); ElasticsearchException ex = expectThrows(ElasticsearchException.class, diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json index d92c8823b73..7c41285bd1c 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json @@ -46,14 +46,21 @@ "description": "Should the model definition be decompressed into valid JSON or returned in a custom compressed format. Defaults to true." }, "from":{ + "required": false, "type":"int", "description":"skips a number of trained models", "default":0 }, "size":{ + "required": false, "type":"int", "description":"specifies a max number of trained models to get", "default":100 + }, + "tags": { + "required": false, + "type":"list", + "description":"A comma-separated list of tags that the model must have." } } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index f5f9a56bab8..b97caf32948 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -9,6 +9,7 @@ setup: body: > { "description": "empty model for tests", + "tags": ["regression", "tag1"], "input": {"field_names": ["field1", "field2"]}, "definition": { "preprocessors": [], @@ -33,6 +34,7 @@ setup: { "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, + "tags": ["regression", "tag2"], "definition": { "preprocessors": [], "trained_model": { @@ -55,6 +57,7 @@ setup: { "description": "empty model for tests", "input": {"field_names": ["field1", "field2"]}, + "tags": ["classification", "tag2"], "definition": { "preprocessors": [], "trained_model": { @@ -128,6 +131,38 @@ setup: - match: { count: 4 } - match: { trained_model_configs.0.model_id: "a-regression-model-0" } --- +"Test get models with tags": + - do: + ml.get_trained_models: + model_id: "*" + tags: "regression,tag1" + - match: { count: 1 } + - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + + - do: + ml.get_trained_models: + model_id: "a-regression*" + tags: "tag1" + - match: { count: 1 } + - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + + - do: + ml.get_trained_models: + model_id: "*" + tags: "tag2" + - match: { count: 2 } + - match: { trained_model_configs.0.model_id: "a-classification-model" } + - match: { trained_model_configs.1.model_id: "a-regression-model-1" } + + - do: + ml.get_trained_models: + model_id: "*" + tags: "tag2" + from: 1 + size: 1 + - match: { count: 2 } + - match: { trained_model_configs.0.model_id: "a-regression-model-1" } +--- "Test delete given unused trained model": - do: ml.delete_trained_model: