From 76660a5a4f49eb0e96e3ec321d070c3c15f046ea Mon Sep 17 00:00:00 2001
From: Benjamin Trent <ben.w.trent@gmail.com>
Date: Fri, 24 Jan 2020 08:26:58 -0500
Subject: [PATCH] [7.x] [ML][Inference] add tags url param to GET (#51330)
 (#51404)

* [ML][Inference] add tags url param to GET (#51330)

Adds a new URL parameter, `tags` to the GET _ml/inference/<model_id> endpoint.

This parameter allows the list of models to be further reduced to those who contain all the provided tags.
---
 .../client/MLRequestConverters.java           |  3 ++
 .../client/ml/GetTrainedModelsRequest.java    | 26 ++++++++++++++
 .../client/MLRequestConvertersTests.java      |  2 ++
 .../MlClientDocumentationIT.java              |  4 ++-
 .../high-level/ml/get-trained-models.asciidoc |  3 ++
 .../apis/get-inference-trained-model.asciidoc |  5 ++-
 docs/reference/ml/ml-shared.asciidoc          |  6 ++++
 .../ml/action/GetTrainedModelsAction.java     | 22 ++++++++++--
 .../action/GetTrainedModelsRequestTests.java  |  5 ++-
 .../TransportGetTrainedModelsAction.java      |  7 +++-
 .../TransportGetTrainedModelsStatsAction.java |  7 +++-
 .../persistence/TrainedModelProvider.java     | 26 ++++++++++++--
 .../inference/RestGetTrainedModelsAction.java |  5 ++-
 .../TrainedModelProviderTests.java            | 26 ++++++++++++++
 .../api/ml.get_trained_models.json            |  7 ++++
 .../rest-api-spec/test/ml/inference_crud.yml  | 35 +++++++++++++++++++
 16 files changed, 177 insertions(+), 12 deletions(-)

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
index 2e077f547e3..bf220d63b3c 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
@@ -755,6 +755,9 @@ final class MLRequestConverters {
             params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
                 Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
         }
+        if (getTrainedModelsRequest.getTags() != null) {
+            params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
+        }
         Request request = new Request(HttpGet.METHOD_NAME, endpoint);
         request.addParameters(params.asMap());
         return request;
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java
index 9234770a97a..d9aeb52d973 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java
@@ -22,6 +22,7 @@ package org.elasticsearch.client.ml;
 import org.elasticsearch.client.Validatable;
 import org.elasticsearch.client.ValidationException;
 import org.elasticsearch.client.core.PageParams;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.common.Nullable;
 
 import java.util.Arrays;
@@ -34,12 +35,14 @@ public class GetTrainedModelsRequest implements Validatable {
     public static final String ALLOW_NO_MATCH = "allow_no_match";
     public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
+    public static final String TAGS = "tags";
 
     private final List<String> ids;
     private Boolean allowNoMatch;
     private Boolean includeDefinition;
     private Boolean decompressDefinition;
     private PageParams pageParams;
+    private List<String> tags;
 
     /**
      * Helper method to create a request that will get ALL TrainedModelConfigs
@@ -111,6 +114,29 @@ public class GetTrainedModelsRequest implements Validatable {
         return this;
     }
 
+    public List<String> getTags() {
+        return tags;
+    }
+
+    /**
+     * The tags that the trained model must match. These correspond to {@link TrainedModelConfig#getTags()}.
+     *
+     * The models returned will match ALL tags supplied.
+     * If none are provided, only the provided ids are used to find models
+     * @param tags The tags to match when finding models
+     */
+    public GetTrainedModelsRequest setTags(List<String> tags) {
+        this.tags = tags;
+        return this;
+    }
+
+    /**
+     * See {@link GetTrainedModelsRequest#setTags(List)}
+     */
+    public GetTrainedModelsRequest setTags(String... tags) {
+        return setTags(Arrays.asList(tags));
+    }
+
     @Override
     public Optional<ValidationException> validate() {
         if (ids == null || ids.isEmpty()) {
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
index f5733ef3a0d..395c62ff837 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
@@ -834,6 +834,7 @@ public class MLRequestConvertersTests extends ESTestCase {
             .setAllowNoMatch(false)
             .setDecompressDefinition(true)
             .setIncludeDefinition(false)
+            .setTags("tag1", "tag2")
             .setPageParams(new PageParams(100, 300));
 
         Request request = MLRequestConverters.getTrainedModels(getRequest);
@@ -845,6 +846,7 @@ public class MLRequestConvertersTests extends ESTestCase {
                 hasEntry("size", "300"),
                 hasEntry("allow_no_match", "false"),
                 hasEntry("decompress_definition", "true"),
+                hasEntry("tags", "tag1,tag2"),
                 hasEntry("include_model_definition", "false")
             ));
         assertNull(request.getEntity());
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
index 142f1f1f660..5b6fbfe043a 100644
--- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
@@ -3587,8 +3587,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .setPageParams(new PageParams(0, 1)) // <2>
                 .setIncludeDefinition(false) // <3>
                 .setDecompressDefinition(false) // <4>
-                .setAllowNoMatch(true); // <5>
+                .setAllowNoMatch(true) // <5>
+                .setTags("regression"); // <6>
             // end::get-trained-models-request
+            request.setTags((List<String>)null);
 
             // tag::get-trained-models-execute
             GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);
diff --git a/docs/java-rest/high-level/ml/get-trained-models.asciidoc b/docs/java-rest/high-level/ml/get-trained-models.asciidoc
index 4ad9f009126..42cd060d881 100644
--- a/docs/java-rest/high-level/ml/get-trained-models.asciidoc
+++ b/docs/java-rest/high-level/ml/get-trained-models.asciidoc
@@ -29,6 +29,9 @@ include-tagged::{doc-tests-file}[{api}-request]
 <5> Allow empty response if no Trained Models match the provided ID patterns.
     If false, an error will be thrown if no Trained Models match the
     ID patterns.
+<6> An optional list of tags used to narrow the model search. A Trained Model
+    can have many tags or none. The trained models in the response will
+    contain all the provided tags.
 
 include::../execution.asciidoc[]
 
diff --git a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc
index e972cf040e2..af4e1a2c9ce 100644
--- a/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc
+++ b/docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc
@@ -74,6 +74,9 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=include-model-definition]
 (Optional, integer) 
 include::{docdir}/ml/ml-shared.asciidoc[tag=size]
 
+`tags`::
+(Optional, string)
+include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
 
 [[ml-get-inference-response-codes]]
 ==== {api-response-codes-title}
@@ -96,4 +99,4 @@ The following example gets configuration information for all the trained models:
 --------------------------------------------------
 GET _ml/inference/
 --------------------------------------------------
-// TEST[skip:TBD]
\ No newline at end of file
+// TEST[skip:TBD]
diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc
index c0219c225ca..893bf0c9e48 100644
--- a/docs/reference/ml/ml-shared.asciidoc
+++ b/docs/reference/ml/ml-shared.asciidoc
@@ -642,6 +642,12 @@ to `false`. When `true`, only a single model must match the ID patterns
 provided, otherwise a bad request is returned.
 end::include-model-definition[]
 
+tag::tags[]
+A comma delimited string of tags. A {infer} model can have many tags, or none.
+When supplied, only {infer} models that contain all the supplied tags are
+returned.
+end::tags[]
+
 tag::indices[]
 An array of index names. Wildcards are supported. For example:
 `["it_ops_metrics", "server*"]`.
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java
index b86cfced552..84330f7924a 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java
@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -33,18 +34,26 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
 
         public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
         public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
+        public static final ParseField TAGS = new ParseField("tags");
 
         private final boolean includeModelDefinition;
+        private final List<String> tags;
 
-        public Request(String id, boolean includeModelDefinition) {
+        public Request(String id, boolean includeModelDefinition, List<String> tags) {
             setResourceId(id);
             setAllowNoResources(true);
             this.includeModelDefinition = includeModelDefinition;
+            this.tags = tags == null ? Collections.emptyList() : tags;
         }
 
         public Request(StreamInput in) throws IOException {
             super(in);
             this.includeModelDefinition = in.readBoolean();
+            if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+                this.tags = in.readStringList();
+            } else {
+                this.tags = Collections.emptyList();
+            }
         }
 
         @Override
@@ -56,15 +65,22 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
             return includeModelDefinition;
         }
 
+        public List<String> getTags() {
+            return tags;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
             out.writeBoolean(includeModelDefinition);
+            if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+                out.writeStringCollection(tags);
+            }
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(super.hashCode(), includeModelDefinition);
+            return Objects.hash(super.hashCode(), includeModelDefinition, tags);
         }
 
         @Override
@@ -76,7 +92,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
                 return false;
             }
             Request other = (Request) obj;
-            return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
+            return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
         }
     }
 
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
index 85345467df1..7955117e117 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
@@ -14,7 +14,10 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
 
     @Override
     protected Request createTestInstance() {
-        Request request = new Request(randomAlphaOfLength(20), randomBoolean());
+        Request request = new Request(randomAlphaOfLength(20),
+            randomBoolean(),
+            randomBoolean() ? null :
+            randomList(10, () -> randomAlphaOfLength(10)));
         request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
         return request;
     }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java
index 4b467912e1b..1ffc13b8b11 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java
@@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.Set;
 
 
@@ -70,7 +71,11 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
             listener::onFailure
         );
 
-        provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
+        provider.expandIds(request.getResourceId(),
+            request.isAllowNoResources(),
+            request.getPageParams(),
+            new HashSet<>(request.getTags()),
+            idExpansionListener);
     }
 
 }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java
index a15579b62de..33678ab9089 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java
@@ -30,6 +30,7 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
@@ -94,7 +95,11 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
             listener::onFailure
         );
 
-        trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener);
+        trainedModelProvider.expandIds(request.getResourceId(),
+            request.isAllowNoResources(),
+            request.getPageParams(),
+            Collections.emptySet(),
+            idsListener);
     }
 
     static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
index d63dbf1bc4b..f3cf9f56aab 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
@@ -70,6 +70,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.net.URL;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
@@ -382,6 +383,7 @@ public class TrainedModelProvider {
     public void expandIds(String idExpression,
                           boolean allowNoResources,
                           @Nullable PageParams pageParams,
+                          Set<String> tags,
                           ActionListener<Tuple<Long, Set<String>>> idsListener) {
         String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
@@ -389,7 +391,7 @@ public class TrainedModelProvider {
                 // If there are no resources, there might be no mapping for the id field.
                 // This makes sure we don't get an error if that happens.
                 .unmappedType("long"))
-            .query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
+            .query(buildExpandIdsQuery(tokens, tags));
         if (pageParams != null) {
             sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
         }
@@ -405,13 +407,23 @@ public class TrainedModelProvider {
                 indicesOptions.expandWildcardsClosed(),
                 indicesOptions))
             .source(sourceBuilder);
+        Set<String> foundResourceIds = new LinkedHashSet<>();
+        if (tags.isEmpty()) {
+            foundResourceIds.addAll(matchedResourceIds(tokens));
+        } else {
+            for(String resourceId : matchedResourceIds(tokens)) {
+                // Does the model as a resource have all the tags?
+                if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
+                    foundResourceIds.add(resourceId);
+                }
+            }
+        }
 
         executeAsyncWithOrigin(client.threadPool().getThreadContext(),
             ML_ORIGIN,
             searchRequest,
             ActionListener.<SearchResponse>wrap(
                 response -> {
-                    Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
                     long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
                     for (SearchHit hit : response.getHits().getHits()) {
                         Map<String, Object> docSource = hit.getSourceAsMap();
@@ -434,7 +446,15 @@ public class TrainedModelProvider {
                 idsListener::onFailure
             ),
             client::search);
+    }
 
+    static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
+        BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
+            .filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
+        for(String tag : tags) {
+            boolQueryBuilder.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), tag));
+        }
+        return QueryBuilders.constantScoreQuery(boolQueryBuilder);
     }
 
     TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
@@ -468,7 +488,7 @@ public class TrainedModelProvider {
         }
     }
 
-    private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
+    private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
         BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
             .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
 
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java
index 1aa0fd42350..4d818f974f7 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java
@@ -18,7 +18,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.ml.MachineLearning;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 import java.util.Set;
 
 import static org.elasticsearch.rest.RestRequest.Method.GET;
@@ -47,7 +49,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
             GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
             false
         );
-        GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition);
+        List<String> tags = Arrays.asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
+        GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
         if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
             request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
                 restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java
index 705a8f60dd2..1f90313899d 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java
@@ -9,16 +9,24 @@ import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 
+import java.util.Arrays;
+
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.Matchers.oneOf;
 import static org.mockito.Mockito.mock;
 
 public class TrainedModelProviderTests extends ESTestCase {
@@ -60,6 +68,24 @@ public class TrainedModelProviderTests extends ESTestCase {
         }
     }
 
+    public void testExpandIdsQuery() {
+        QueryBuilder queryBuilder = TrainedModelProvider.buildExpandIdsQuery(new String[]{"model*", "trained_mode"},
+            Arrays.asList("tag1", "tag2"));
+        assertThat(queryBuilder, is(instanceOf(ConstantScoreQueryBuilder.class)));
+
+        QueryBuilder innerQuery = ((ConstantScoreQueryBuilder)queryBuilder).innerQuery();
+        assertThat(innerQuery, is(instanceOf(BoolQueryBuilder.class)));
+
+        ((BoolQueryBuilder)innerQuery).filter().forEach(qb -> {
+            if (qb instanceof TermQueryBuilder) {
+                assertThat(((TermQueryBuilder)qb).fieldName(), equalTo(TrainedModelConfig.TAGS.getPreferredName()));
+                assertThat(((TermQueryBuilder)qb).value(), is(oneOf("tag1", "tag2")));
+                return;
+            }
+            assertThat(qb, is(instanceOf(BoolQueryBuilder.class)));
+        });
+    }
+
     public void testGetModelThatExistsAsResourceButIsMissing() {
         TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
         ElasticsearchException ex = expectThrows(ElasticsearchException.class,
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json
index d92c8823b73..7c41285bd1c 100644
--- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json
@@ -46,14 +46,21 @@
         "description": "Should the model definition be decompressed into valid JSON or returned in a custom compressed format. Defaults to true."
       },
       "from":{
+        "required": false,
         "type":"int",
         "description":"skips a number of trained models",
         "default":0
       },
       "size":{
+        "required": false,
         "type":"int",
         "description":"specifies a max number of trained models to get",
         "default":100
+      },
+      "tags": {
+        "required": false,
+        "type":"list",
+        "description":"A comma-separated list of tags that the model must have."
       }
     }
   }
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
index f5f9a56bab8..b97caf32948 100644
--- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
+++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml
@@ -9,6 +9,7 @@ setup:
         body: >
           {
             "description": "empty model for tests",
+            "tags": ["regression", "tag1"],
             "input": {"field_names": ["field1", "field2"]},
             "definition": {
                "preprocessors": [],
@@ -33,6 +34,7 @@ setup:
           {
             "description": "empty model for tests",
             "input": {"field_names": ["field1", "field2"]},
+            "tags": ["regression", "tag2"],
             "definition": {
                "preprocessors": [],
                "trained_model": {
@@ -55,6 +57,7 @@ setup:
           {
             "description": "empty model for tests",
             "input": {"field_names": ["field1", "field2"]},
+            "tags": ["classification", "tag2"],
             "definition": {
                "preprocessors": [],
                "trained_model": {
@@ -128,6 +131,38 @@ setup:
   - match: { count: 4 }
   - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
 ---
+"Test get models with tags":
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        tags: "regression,tag1"
+  - match: { count: 1 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "a-regression*"
+        tags: "tag1"
+  - match: { count: 1 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        tags: "tag2"
+  - match: { count: 2 }
+  - match: { trained_model_configs.0.model_id: "a-classification-model" }
+  - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        tags: "tag2"
+        from: 1
+        size: 1
+  - match: { count: 2 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-1" }
+---
 "Test delete given unused trained model":
   - do:
       ml.delete_trained_model: