From e372854d43b51a3aae05d8d192b2b4872aaa550b Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 31 Jan 2020 07:52:19 -0500 Subject: [PATCH] [ML][Inference] Fix model pagination with models as resources (#51573) (#51736) This adds logic to handle paging problems when the ID pattern + tags reference models stored as resources. Most of the complexity comes from the issue where a model stored as a resource could be at the start, or the end of a page or when we are on the last page. --- .../persistence/TrainedModelProvider.java | 78 ++++++++--- .../TrainedModelProviderTests.java | 48 +++++++ .../rest-api-spec/test/ml/inference_crud.yml | 132 +++++++++++++++++- 3 files changed, 232 insertions(+), 26 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index f3cf9f56aab..7c6dee60f2e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -28,7 +28,6 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.CheckedBiFunction; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -74,10 +73,10 @@ import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.TreeSet; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -382,19 +381,34 @@ public class TrainedModelProvider { public void expandIds(String idExpression, boolean allowNoResources, - @Nullable PageParams pageParams, + PageParams pageParams, Set tags, ActionListener>> idsListener) { String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); + Set matchedResourceIds = matchedResourceIds(tokens); + Set foundResourceIds; + if (tags.isEmpty()) { + foundResourceIds = matchedResourceIds; + } else { + foundResourceIds = new HashSet<>(); + for(String resourceId : matchedResourceIds) { + // Does the model as a resource have all the tags? + if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { + foundResourceIds.add(resourceId); + } + } + } SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName()) // If there are no resources, there might be no mapping for the id field. // This makes sure we don't get an error if that happens. .unmappedType("long")) - .query(buildExpandIdsQuery(tokens, tags)); - if (pageParams != null) { - sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize()); - } + .query(buildExpandIdsQuery(tokens, tags)) + // We "buffer" the from and size to take into account models stored as resources. + // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of + // a page. + .from(Math.max(0, pageParams.getFrom() - foundResourceIds.size())) + .size(Math.min(10_000, pageParams.getSize() + foundResourceIds.size())); sourceBuilder.trackTotalHits(true) // we only care about the item id's .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); @@ -407,17 +421,6 @@ public class TrainedModelProvider { indicesOptions.expandWildcardsClosed(), indicesOptions)) .source(sourceBuilder); - Set foundResourceIds = new LinkedHashSet<>(); - if (tags.isEmpty()) { - foundResourceIds.addAll(matchedResourceIds(tokens)); - } else { - for(String resourceId : matchedResourceIds(tokens)) { - // Does the model as a resource have all the tags? - if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { - foundResourceIds.add(resourceId); - } - } - } executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, @@ -425,6 +428,7 @@ public class TrainedModelProvider { ActionListener.wrap( response -> { long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size(); + Set foundFromDocs = new HashSet<>(); for (SearchHit hit : response.getHits().getHits()) { Map docSource = hit.getSourceAsMap(); if (docSource == null) { @@ -432,15 +436,17 @@ public class TrainedModelProvider { } Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); if (idValue instanceof String) { - foundResourceIds.add(idValue.toString()); + foundFromDocs.add(idValue.toString()); } } + Set allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs); ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); - requiredMatches.filterMatchedIds(foundResourceIds); + requiredMatches.filterMatchedIds(allFoundIds); if (requiredMatches.hasUnmatchedIds()) { idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); } else { - idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); + + idsListener.onResponse(Tuple.tuple(totalHitCount, allFoundIds)); } }, idsListener::onFailure @@ -448,6 +454,32 @@ public class TrainedModelProvider { client::search); } + static Set collectIds(PageParams pageParams, Set foundFromResources, Set foundFromDocs) { + // If there are no matching resource models, there was no buffering and the models from the docs + // are paginated correctly. + if (foundFromResources.isEmpty()) { + return foundFromDocs; + } + + TreeSet allFoundIds = new TreeSet<>(foundFromDocs); + allFoundIds.addAll(foundFromResources); + + if (pageParams.getFrom() > 0) { + // not the first page so there will be extra results at the front to remove + int numToTrimFromFront = Math.min(foundFromResources.size(), pageParams.getFrom()); + for (int i = 0; i < numToTrimFromFront; i++) { + allFoundIds.remove(allFoundIds.first()); + } + } + + // trim down to size removing from the rear + while (allFoundIds.size() > pageParams.getSize()) { + allFoundIds.remove(allFoundIds.last()); + } + + return allFoundIds; + } + static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection tags) { BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery() .filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())); @@ -518,7 +550,7 @@ public class TrainedModelProvider { private Set matchedResourceIds(String[] tokens) { if (Strings.isAllOrWildcard(tokens)) { - return new HashSet<>(MODELS_STORED_AS_RESOURCE); + return MODELS_STORED_AS_RESOURCE; } Set matchedModels = new HashSet<>(); @@ -536,7 +568,7 @@ public class TrainedModelProvider { } } } - return matchedModels; + return Collections.unmodifiableSet(matchedModels); } private static T handleSearchItem(MultiSearchResponse.Item item, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 1f90313899d..aee4c43f227 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -14,12 +14,16 @@ import org.elasticsearch.index.query.ConstantScoreQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.TreeSet; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -86,6 +90,50 @@ public class TrainedModelProviderTests extends ESTestCase { }); } + public void testExpandIdsPagination() { + // NOTE: these tests assume that the query pagination results are "buffered" + + assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3), + Collections.emptySet(), + new HashSet<>(Arrays.asList("a", "b", "c"))), + equalTo(new TreeSet<>(Arrays.asList("a", "b", "c")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3), + Collections.singleton("a"), + new HashSet<>(Arrays.asList("b", "c", "d"))), + equalTo(new TreeSet<>(Arrays.asList("a", "b", "c")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3), + Collections.singleton("a"), + new HashSet<>(Arrays.asList("b", "c", "d"))), + equalTo(new TreeSet<>(Arrays.asList("b", "c", "d")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1), + Collections.singleton("c"), + new HashSet<>(Arrays.asList("a", "b"))), + equalTo(new TreeSet<>(Arrays.asList("b")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1), + Collections.singleton("b"), + new HashSet<>(Arrays.asList("a", "c"))), + equalTo(new TreeSet<>(Arrays.asList("b")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 2), + new HashSet<>(Arrays.asList("a", "b")), + new HashSet<>(Arrays.asList("c", "d", "e"))), + equalTo(new TreeSet<>(Arrays.asList("b", "c")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3), + new HashSet<>(Arrays.asList("a", "b")), + new HashSet<>(Arrays.asList("c", "d", "e"))), + equalTo(new TreeSet<>(Arrays.asList("b", "c", "d")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(2, 3), + new HashSet<>(Arrays.asList("a", "b")), + new HashSet<>(Arrays.asList("c", "d", "e"))), + equalTo(new TreeSet<>(Arrays.asList("c", "d", "e")))); + } + public void testGetModelThatExistsAsResourceButIsMissing() { TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); ElasticsearchException ex = expectThrows(ElasticsearchException.class, diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index d7cbc9825b7..0c9fbb350bb 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -72,6 +72,56 @@ setup: } } } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: yyy-classification-model + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "tags": ["classification", "tag3"], + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "classification", + "classification_labels": ["no", "yes"] + } + } + } + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: zzz-classification-model + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "tags": ["classification", "tag3"], + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "classification", + "classification_labels": ["no", "yes"] + } + } + } + } --- "Test get given missing trained model": @@ -102,15 +152,20 @@ setup: - do: ml.get_trained_models: model_id: "*" - - match: { count: 4 } + - match: { count: 6 } + - length: { trained_model_configs: 6 } - match: { trained_model_configs.0.model_id: "a-classification-model" } - match: { trained_model_configs.1.model_id: "a-regression-model-0" } - match: { trained_model_configs.2.model_id: "a-regression-model-1" } + - match: { trained_model_configs.3.model_id: "lang_ident_model_1" } + - match: { trained_model_configs.4.model_id: "yyy-classification-model" } + - match: { trained_model_configs.5.model_id: "zzz-classification-model" } - do: ml.get_trained_models: model_id: "a-regression*" - match: { count: 2 } + - length: { trained_model_configs: 2 } - match: { trained_model_configs.0.model_id: "a-regression-model-0" } - match: { trained_model_configs.1.model_id: "a-regression-model-1" } @@ -119,7 +174,8 @@ setup: model_id: "*" from: 0 size: 2 - - match: { count: 4 } + - match: { count: 6 } + - length: { trained_model_configs: 2 } - match: { trained_model_configs.0.model_id: "a-classification-model" } - match: { trained_model_configs.1.model_id: "a-regression-model-0" } @@ -128,8 +184,78 @@ setup: model_id: "*" from: 1 size: 1 - - match: { count: 4 } + - match: { count: 6 } + - length: { trained_model_configs: 1 } - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 2 + size: 2 + - match: { count: 6 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-regression-model-1" } + - match: { trained_model_configs.1.model_id: "lang_ident_model_1" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 3 + size: 1 + - match: { count: 6 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 3 + size: 2 + - match: { count: 6 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } + - match: { trained_model_configs.1.model_id: "yyy-classification-model" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 4 + size: 2 + - match: { count: 6 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "yyy-classification-model" } + - match: { trained_model_configs.1.model_id: "zzz-classification-model" } + + - do: + ml.get_trained_models: + model_id: "a-*,lang*,zzz*" + allow_no_match: true + from: 3 + size: 1 + - match: { count: 5 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } + + - do: + ml.get_trained_models: + model_id: "a-*,lang*,zzz*" + allow_no_match: true + from: 4 + size: 1 + - match: { count: 5 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "zzz-classification-model" } + + - do: + ml.get_trained_models: + model_id: "a-*,lang*,zzz*" + from: 4 + size: 100 + - match: { count: 5 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "zzz-classification-model" } + --- "Test get models with tags": - do: