[7.x] [ML][Inference] add tags url param to GET (#51330) (#51404)

* [ML][Inference] add tags url param to GET (#51330)

Adds a new URL parameter, `tags` to the GET _ml/inference/<model_id> endpoint.

This parameter allows the list of models to be further reduced to those who contain all the provided tags.
This commit is contained in:
Benjamin Trent 2020-01-24 08:26:58 -05:00 committed by GitHub
parent d3078c5b40
commit 76660a5a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 177 additions and 12 deletions

View File

@ -755,6 +755,9 @@ final class MLRequestConverters {
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
}
if (getTrainedModelsRequest.getTags() != null) {
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
}
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.addParameters(params.asMap());
return request;

View File

@ -22,6 +22,7 @@ package org.elasticsearch.client.ml;
import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.core.PageParams;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.common.Nullable;
import java.util.Arrays;
@ -34,12 +35,14 @@ public class GetTrainedModelsRequest implements Validatable {
public static final String ALLOW_NO_MATCH = "allow_no_match";
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String TAGS = "tags";
private final List<String> ids;
private Boolean allowNoMatch;
private Boolean includeDefinition;
private Boolean decompressDefinition;
private PageParams pageParams;
private List<String> tags;
/**
* Helper method to create a request that will get ALL TrainedModelConfigs
@ -111,6 +114,29 @@ public class GetTrainedModelsRequest implements Validatable {
return this;
}
public List<String> getTags() {
return tags;
}
/**
* The tags that the trained model must match. These correspond to {@link TrainedModelConfig#getTags()}.
*
* The models returned will match ALL tags supplied.
* If none are provided, only the provided ids are used to find models
* @param tags The tags to match when finding models
*/
public GetTrainedModelsRequest setTags(List<String> tags) {
this.tags = tags;
return this;
}
/**
* See {@link GetTrainedModelsRequest#setTags(List)}
*/
public GetTrainedModelsRequest setTags(String... tags) {
return setTags(Arrays.asList(tags));
}
@Override
public Optional<ValidationException> validate() {
if (ids == null || ids.isEmpty()) {

View File

@ -834,6 +834,7 @@ public class MLRequestConvertersTests extends ESTestCase {
.setAllowNoMatch(false)
.setDecompressDefinition(true)
.setIncludeDefinition(false)
.setTags("tag1", "tag2")
.setPageParams(new PageParams(100, 300));
Request request = MLRequestConverters.getTrainedModels(getRequest);
@ -845,6 +846,7 @@ public class MLRequestConvertersTests extends ESTestCase {
hasEntry("size", "300"),
hasEntry("allow_no_match", "false"),
hasEntry("decompress_definition", "true"),
hasEntry("tags", "tag1,tag2"),
hasEntry("include_model_definition", "false")
));
assertNull(request.getEntity());

View File

@ -3587,8 +3587,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
.setPageParams(new PageParams(0, 1)) // <2>
.setIncludeDefinition(false) // <3>
.setDecompressDefinition(false) // <4>
.setAllowNoMatch(true); // <5>
.setAllowNoMatch(true) // <5>
.setTags("regression"); // <6>
// end::get-trained-models-request
request.setTags((List<String>)null);
// tag::get-trained-models-execute
GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);

View File

@ -29,6 +29,9 @@ include-tagged::{doc-tests-file}[{api}-request]
<5> Allow empty response if no Trained Models match the provided ID patterns.
If false, an error will be thrown if no Trained Models match the
ID patterns.
<6> An optional list of tags used to narrow the model search. A Trained Model
can have many tags or none. The trained models in the response will
contain all the provided tags.
include::../execution.asciidoc[]

View File

@ -74,6 +74,9 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=include-model-definition]
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=size]
`tags`::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
[[ml-get-inference-response-codes]]
==== {api-response-codes-title}
@ -96,4 +99,4 @@ The following example gets configuration information for all the trained models:
--------------------------------------------------
GET _ml/inference/
--------------------------------------------------
// TEST[skip:TBD]
// TEST[skip:TBD]

View File

@ -642,6 +642,12 @@ to `false`. When `true`, only a single model must match the ID patterns
provided, otherwise a bad request is returned.
end::include-model-definition[]
tag::tags[]
A comma delimited string of tags. A {infer} model can have many tags, or none.
When supplied, only {infer} models that contain all the supplied tags are
returned.
end::tags[]
tag::indices[]
An array of index names. Wildcards are supported. For example:
`["it_ops_metrics", "server*"]`.

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@ -33,18 +34,26 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");
private final boolean includeModelDefinition;
private final List<String> tags;
public Request(String id, boolean includeModelDefinition) {
public Request(String id, boolean includeModelDefinition, List<String> tags) {
setResourceId(id);
setAllowNoResources(true);
this.includeModelDefinition = includeModelDefinition;
this.tags = tags == null ? Collections.emptyList() : tags;
}
public Request(StreamInput in) throws IOException {
super(in);
this.includeModelDefinition = in.readBoolean();
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
this.tags = in.readStringList();
} else {
this.tags = Collections.emptyList();
}
}
@Override
@ -56,15 +65,22 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
return includeModelDefinition;
}
public List<String> getTags() {
return tags;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(includeModelDefinition);
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
out.writeStringCollection(tags);
}
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), includeModelDefinition);
return Objects.hash(super.hashCode(), includeModelDefinition, tags);
}
@Override
@ -76,7 +92,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
return false;
}
Request other = (Request) obj;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
}
}

View File

@ -14,7 +14,10 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
Request request = new Request(randomAlphaOfLength(20),
randomBoolean(),
randomBoolean() ? null :
randomList(10, () -> randomAlphaOfLength(10)));
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
return request;
}

View File

@ -20,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
@ -70,7 +71,11 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
listener::onFailure
);
provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
provider.expandIds(request.getResourceId(),
request.isAllowNoResources(),
request.getPageParams(),
new HashSet<>(request.getTags()),
idExpansionListener);
}
}

View File

@ -30,6 +30,7 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
@ -94,7 +95,11 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
listener::onFailure
);
trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener);
trainedModelProvider.expandIds(request.getResourceId(),
request.isAllowNoResources(),
request.getPageParams(),
Collections.emptySet(),
idsListener);
}
static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,

View File

@ -70,6 +70,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
@ -382,6 +383,7 @@ public class TrainedModelProvider {
public void expandIds(String idExpression,
boolean allowNoResources,
@Nullable PageParams pageParams,
Set<String> tags,
ActionListener<Tuple<Long, Set<String>>> idsListener) {
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
@ -389,7 +391,7 @@ public class TrainedModelProvider {
// If there are no resources, there might be no mapping for the id field.
// This makes sure we don't get an error if that happens.
.unmappedType("long"))
.query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
.query(buildExpandIdsQuery(tokens, tags));
if (pageParams != null) {
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
}
@ -405,13 +407,23 @@ public class TrainedModelProvider {
indicesOptions.expandWildcardsClosed(),
indicesOptions))
.source(sourceBuilder);
Set<String> foundResourceIds = new LinkedHashSet<>();
if (tags.isEmpty()) {
foundResourceIds.addAll(matchedResourceIds(tokens));
} else {
for(String resourceId : matchedResourceIds(tokens)) {
// Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
foundResourceIds.add(resourceId);
}
}
}
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(
response -> {
Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSource = hit.getSourceAsMap();
@ -434,7 +446,15 @@ public class TrainedModelProvider {
idsListener::onFailure
),
client::search);
}
static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
.filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
for(String tag : tags) {
boolQueryBuilder.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), tag));
}
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
}
TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
@ -468,7 +488,7 @@ public class TrainedModelProvider {
}
}
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));

View File

@ -18,7 +18,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import static org.elasticsearch.rest.RestRequest.Method.GET;
@ -47,7 +49,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
false
);
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition);
List<String> tags = Arrays.asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));

View File

@ -9,16 +9,24 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import java.util.Arrays;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.oneOf;
import static org.mockito.Mockito.mock;
public class TrainedModelProviderTests extends ESTestCase {
@ -60,6 +68,24 @@ public class TrainedModelProviderTests extends ESTestCase {
}
}
public void testExpandIdsQuery() {
QueryBuilder queryBuilder = TrainedModelProvider.buildExpandIdsQuery(new String[]{"model*", "trained_mode"},
Arrays.asList("tag1", "tag2"));
assertThat(queryBuilder, is(instanceOf(ConstantScoreQueryBuilder.class)));
QueryBuilder innerQuery = ((ConstantScoreQueryBuilder)queryBuilder).innerQuery();
assertThat(innerQuery, is(instanceOf(BoolQueryBuilder.class)));
((BoolQueryBuilder)innerQuery).filter().forEach(qb -> {
if (qb instanceof TermQueryBuilder) {
assertThat(((TermQueryBuilder)qb).fieldName(), equalTo(TrainedModelConfig.TAGS.getPreferredName()));
assertThat(((TermQueryBuilder)qb).value(), is(oneOf("tag1", "tag2")));
return;
}
assertThat(qb, is(instanceOf(BoolQueryBuilder.class)));
});
}
public void testGetModelThatExistsAsResourceButIsMissing() {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
ElasticsearchException ex = expectThrows(ElasticsearchException.class,

View File

@ -46,14 +46,21 @@
"description": "Should the model definition be decompressed into valid JSON or returned in a custom compressed format. Defaults to true."
},
"from":{
"required": false,
"type":"int",
"description":"skips a number of trained models",
"default":0
},
"size":{
"required": false,
"type":"int",
"description":"specifies a max number of trained models to get",
"default":100
},
"tags": {
"required": false,
"type":"list",
"description":"A comma-separated list of tags that the model must have."
}
}
}

View File

@ -9,6 +9,7 @@ setup:
body: >
{
"description": "empty model for tests",
"tags": ["regression", "tag1"],
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
@ -33,6 +34,7 @@ setup:
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["regression", "tag2"],
"definition": {
"preprocessors": [],
"trained_model": {
@ -55,6 +57,7 @@ setup:
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["classification", "tag2"],
"definition": {
"preprocessors": [],
"trained_model": {
@ -128,6 +131,38 @@ setup:
- match: { count: 4 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
---
"Test get models with tags":
- do:
ml.get_trained_models:
model_id: "*"
tags: "regression,tag1"
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
- do:
ml.get_trained_models:
model_id: "a-regression*"
tags: "tag1"
- match: { count: 1 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
- do:
ml.get_trained_models:
model_id: "*"
tags: "tag2"
- match: { count: 2 }
- match: { trained_model_configs.0.model_id: "a-classification-model" }
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
- do:
ml.get_trained_models:
model_id: "*"
tags: "tag2"
from: 1
size: 1
- match: { count: 2 }
- match: { trained_model_configs.0.model_id: "a-regression-model-1" }
---
"Test delete given unused trained model":
- do:
ml.delete_trained_model: