[ML][Inference] Fix model pagination with models as resources (#51573) (#51736)

This adds logic to handle paging problems when the ID pattern + tags reference models stored as resources. 

Most of the complexity comes from handling the cases where a model stored as a resource falls at the start or end of a page, or where the request lands on the last page.
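In short, the search request over-fetches by the number of matching resource-backed models, and the combined, sorted set of IDs is then trimmed back down to the requested page. A minimal standalone sketch of that trim step (the class and method names here are illustrative, not the exact production code):

```java
import java.util.Set;
import java.util.TreeSet;

final class ResourceAwarePaging {

    /**
     * Combine the IDs returned by the (over-fetched) search with the IDs of
     * models stored as resources, then trim the sorted union back down to the
     * requested page window. Mirrors the collectIds(...) helper added below.
     */
    static Set<String> collectPage(int from, int size,
                                   Set<String> foundFromResources,
                                   Set<String> foundFromDocs) {
        // No resource-backed models matched: the search results already form a
        // correctly paginated page, so there is nothing to adjust.
        if (foundFromResources.isEmpty()) {
            return foundFromDocs;
        }
        TreeSet<String> allFoundIds = new TreeSet<>(foundFromDocs);
        allFoundIds.addAll(foundFromResources);
        if (from > 0) {
            // Not the first page: the buffered "from" pulled in up to
            // foundFromResources.size() IDs that belong to earlier pages.
            int numToTrimFromFront = Math.min(foundFromResources.size(), from);
            for (int i = 0; i < numToTrimFromFront; i++) {
                allFoundIds.pollFirst();
            }
        }
        // The buffered "size" may also have pulled in IDs past the page end.
        while (allFoundIds.size() > size) {
            allFoundIds.pollLast();
        }
        return allFoundIds;
    }
}
```

For example, with resource model `a` and indexed models `b, c, d`, a request of `from=1, size=3` over-fetches with `from=0, size=4`; the trim step then drops `a` from the front and returns `b, c, d`, which is what the new unit test asserts.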
Benjamin Trent 2020-01-31 07:52:19 -05:00 committed by GitHub
parent dfc9f2330c
commit e372854d43
3 changed files with 232 additions and 26 deletions

TrainedModelProvider.java

@@ -28,7 +28,6 @@ import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
@@ -74,10 +73,10 @@ import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -382,19 +381,34 @@ public class TrainedModelProvider {
public void expandIds(String idExpression,
boolean allowNoResources,
@Nullable PageParams pageParams,
PageParams pageParams,
Set<String> tags,
ActionListener<Tuple<Long, Set<String>>> idsListener) {
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
Set<String> matchedResourceIds = matchedResourceIds(tokens);
Set<String> foundResourceIds;
if (tags.isEmpty()) {
foundResourceIds = matchedResourceIds;
} else {
foundResourceIds = new HashSet<>();
for(String resourceId : matchedResourceIds) {
// Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
foundResourceIds.add(resourceId);
}
}
}
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
// If there are no resources, there might be no mapping for the id field.
// This makes sure we don't get an error if that happens.
.unmappedType("long"))
.query(buildExpandIdsQuery(tokens, tags));
if (pageParams != null) {
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
}
.query(buildExpandIdsQuery(tokens, tags))
// We "buffer" the from and size to take into account models stored as resources.
// This is so we handle the edge cases when the model that is stored as a resource is at the start/end of
// a page.
.from(Math.max(0, pageParams.getFrom() - foundResourceIds.size()))
.size(Math.min(10_000, pageParams.getSize() + foundResourceIds.size()));
sourceBuilder.trackTotalHits(true)
// we only care about the item id's
.fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null);
@@ -407,17 +421,6 @@ public class TrainedModelProvider {
indicesOptions.expandWildcardsClosed(),
indicesOptions))
.source(sourceBuilder);
Set<String> foundResourceIds = new LinkedHashSet<>();
if (tags.isEmpty()) {
foundResourceIds.addAll(matchedResourceIds(tokens));
} else {
for(String resourceId : matchedResourceIds(tokens)) {
// Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
foundResourceIds.add(resourceId);
}
}
}
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ML_ORIGIN,
@@ -425,6 +428,7 @@ public class TrainedModelProvider {
ActionListener.<SearchResponse>wrap(
response -> {
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
Set<String> foundFromDocs = new HashSet<>();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSource = hit.getSourceAsMap();
if (docSource == null) {
@@ -432,15 +436,17 @@ public class TrainedModelProvider {
}
Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName());
if (idValue instanceof String) {
foundResourceIds.add(idValue.toString());
foundFromDocs.add(idValue.toString());
}
}
Set<String> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs);
ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
requiredMatches.filterMatchedIds(foundResourceIds);
requiredMatches.filterMatchedIds(allFoundIds);
if (requiredMatches.hasUnmatchedIds()) {
idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
} else {
idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds));
idsListener.onResponse(Tuple.tuple(totalHitCount, allFoundIds));
}
},
idsListener::onFailure
@@ -448,6 +454,32 @@ public class TrainedModelProvider {
client::search);
}
static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) {
// If there are no matching resource models, there was no buffering and the models from the docs
// are paginated correctly.
if (foundFromResources.isEmpty()) {
return foundFromDocs;
}
TreeSet<String> allFoundIds = new TreeSet<>(foundFromDocs);
allFoundIds.addAll(foundFromResources);
if (pageParams.getFrom() > 0) {
// not the first page so there will be extra results at the front to remove
int numToTrimFromFront = Math.min(foundFromResources.size(), pageParams.getFrom());
for (int i = 0; i < numToTrimFromFront; i++) {
allFoundIds.remove(allFoundIds.first());
}
}
// trim down to size removing from the rear
while (allFoundIds.size() > pageParams.getSize()) {
allFoundIds.remove(allFoundIds.last());
}
return allFoundIds;
}
static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
.filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
@@ -518,7 +550,7 @@ public class TrainedModelProvider {
private Set<String> matchedResourceIds(String[] tokens) {
if (Strings.isAllOrWildcard(tokens)) {
return new HashSet<>(MODELS_STORED_AS_RESOURCE);
return MODELS_STORED_AS_RESOURCE;
}
Set<String> matchedModels = new HashSet<>();
@@ -536,7 +568,7 @@ public class TrainedModelProvider {
}
}
}
return matchedModels;
return Collections.unmodifiableSet(matchedModels);
}
private static <T> T handleSearchItem(MultiSearchResponse.Item item,

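To make the window arithmetic used in expandIds above concrete, here is a hedged sketch (the PagingWindows/bufferedWindow names are hypothetical) of how the from/size actually sent with the search request is derived:

```java
final class PagingWindows {

    /**
     * Widen the requested page window by the number of matching resource-backed
     * models so that a resource model landing on a page boundary cannot push an
     * indexed model off the page. The 10,000 cap mirrors the one in expandIds.
     */
    static int[] bufferedWindow(int from, int size, int numResourceMatches) {
        int bufferedFrom = Math.max(0, from - numResourceMatches);
        int bufferedSize = Math.min(10_000, size + numResourceMatches);
        return new int[] { bufferedFrom, bufferedSize };
    }
}
```

With two resource-backed matches and a request of `from=2, size=3`, the search runs with `from=0, size=5`; collectIds then trims the union back to the three IDs in the requested window, matching the testExpandIdsPagination cases in the next file.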
TrainedModelProviderTests.java

@@ -14,12 +14,16 @@ import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.TreeSet;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
@@ -86,6 +90,50 @@ public class TrainedModelProviderTests extends ESTestCase {
});
}
public void testExpandIdsPagination() {
// NOTE: these tests assume that the query pagination results are "buffered"
assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
Collections.emptySet(),
new HashSet<>(Arrays.asList("a", "b", "c"))),
equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
Collections.singleton("a"),
new HashSet<>(Arrays.asList("b", "c", "d"))),
equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
Collections.singleton("a"),
new HashSet<>(Arrays.asList("b", "c", "d"))),
equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
Collections.singleton("c"),
new HashSet<>(Arrays.asList("a", "b"))),
equalTo(new TreeSet<>(Arrays.asList("b"))));
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
Collections.singleton("b"),
new HashSet<>(Arrays.asList("a", "c"))),
equalTo(new TreeSet<>(Arrays.asList("b"))));
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 2),
new HashSet<>(Arrays.asList("a", "b")),
new HashSet<>(Arrays.asList("c", "d", "e"))),
equalTo(new TreeSet<>(Arrays.asList("b", "c"))));
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
new HashSet<>(Arrays.asList("a", "b")),
new HashSet<>(Arrays.asList("c", "d", "e"))),
equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
assertThat(TrainedModelProvider.collectIds(new PageParams(2, 3),
new HashSet<>(Arrays.asList("a", "b")),
new HashSet<>(Arrays.asList("c", "d", "e"))),
equalTo(new TreeSet<>(Arrays.asList("c", "d", "e"))));
}
public void testGetModelThatExistsAsResourceButIsMissing() {
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
ElasticsearchException ex = expectThrows(ElasticsearchException.class,

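The REST test changes below exercise the same behaviour end to end. The following worked example (a hypothetical standalone class; it assumes lang_ident_model_1 is the only entry in MODELS_STORED_AS_RESOURCE, and uses the model IDs created in the test setup) traces the "model_id: *, from: 4, size: 2" request:

```java
import java.util.TreeSet;

// Worked check of the "model_id: *, from: 4, size: 2" case below, assuming
// lang_ident_model_1 is the only model stored as a resource.
public class RestPaginationExample {
    public static void main(String[] args) {
        int from = 4, size = 2, numResources = 1;

        // Buffered window sent with the search request.
        int bufferedFrom = Math.max(0, from - numResources);      // 3
        int bufferedSize = Math.min(10_000, size + numResources); // 3

        // Indexed models, sorted by model_id; the buffered search asks for up
        // to 3 documents starting at offset 3, so it returns the last two.
        String[] indexed = {
            "a-classification-model", "a-regression-model-0", "a-regression-model-1",
            "yyy-classification-model", "zzz-classification-model"
        };
        TreeSet<String> page = new TreeSet<>();
        for (int i = bufferedFrom; i < Math.min(indexed.length, bufferedFrom + bufferedSize); i++) {
            page.add(indexed[i]);
        }
        page.add("lang_ident_model_1"); // union with the resource-backed model

        // Trim back to the requested page: drop the extra ID pulled in at the
        // front, then cut the rear down to the requested size.
        for (int i = 0; i < Math.min(numResources, from); i++) {
            page.pollFirst();
        }
        while (page.size() > size) {
            page.pollLast();
        }

        // Prints [yyy-classification-model, zzz-classification-model],
        // matching the expected REST response.
        System.out.println(page);
    }
}
```

The same arithmetic explains the other from/size combinations asserted below.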
ML trained model REST tests (YAML)

@@ -72,6 +72,56 @@ setup:
}
}
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: yyy-classification-model
body: >
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["classification", "tag3"],
"definition": {
"preprocessors": [],
"trained_model": {
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "leaf_value": 1}
],
"target_type": "classification",
"classification_labels": ["no", "yes"]
}
}
}
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ml.put_trained_model:
model_id: zzz-classification-model
body: >
{
"description": "empty model for tests",
"input": {"field_names": ["field1", "field2"]},
"tags": ["classification", "tag3"],
"definition": {
"preprocessors": [],
"trained_model": {
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "leaf_value": 1}
],
"target_type": "classification",
"classification_labels": ["no", "yes"]
}
}
}
}
---
"Test get given missing trained model":
@@ -102,15 +152,20 @@ setup:
- do:
ml.get_trained_models:
model_id: "*"
- match: { count: 4 }
- match: { count: 6 }
- length: { trained_model_configs: 6 }
- match: { trained_model_configs.0.model_id: "a-classification-model" }
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
- match: { trained_model_configs.2.model_id: "a-regression-model-1" }
- match: { trained_model_configs.3.model_id: "lang_ident_model_1" }
- match: { trained_model_configs.4.model_id: "yyy-classification-model" }
- match: { trained_model_configs.5.model_id: "zzz-classification-model" }
- do:
ml.get_trained_models:
model_id: "a-regression*"
- match: { count: 2 }
- length: { trained_model_configs: 2 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
@@ -119,7 +174,8 @@ setup:
model_id: "*"
from: 0
size: 2
- match: { count: 4 }
- match: { count: 6 }
- length: { trained_model_configs: 2 }
- match: { trained_model_configs.0.model_id: "a-classification-model" }
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
@@ -128,8 +184,78 @@ setup:
model_id: "*"
from: 1
size: 1
- match: { count: 4 }
- match: { count: 6 }
- length: { trained_model_configs: 1 }
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
- do:
ml.get_trained_models:
model_id: "*"
from: 2
size: 2
- match: { count: 6 }
- length: { trained_model_configs: 2 }
- match: { trained_model_configs.0.model_id: "a-regression-model-1" }
- match: { trained_model_configs.1.model_id: "lang_ident_model_1" }
- do:
ml.get_trained_models:
model_id: "*"
from: 3
size: 1
- match: { count: 6 }
- length: { trained_model_configs: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
- do:
ml.get_trained_models:
model_id: "*"
from: 3
size: 2
- match: { count: 6 }
- length: { trained_model_configs: 2 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
- match: { trained_model_configs.1.model_id: "yyy-classification-model" }
- do:
ml.get_trained_models:
model_id: "*"
from: 4
size: 2
- match: { count: 6 }
- length: { trained_model_configs: 2 }
- match: { trained_model_configs.0.model_id: "yyy-classification-model" }
- match: { trained_model_configs.1.model_id: "zzz-classification-model" }
- do:
ml.get_trained_models:
model_id: "a-*,lang*,zzz*"
allow_no_match: true
from: 3
size: 1
- match: { count: 5 }
- length: { trained_model_configs: 1 }
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
- do:
ml.get_trained_models:
model_id: "a-*,lang*,zzz*"
allow_no_match: true
from: 4
size: 1
- match: { count: 5 }
- length: { trained_model_configs: 1 }
- match: { trained_model_configs.0.model_id: "zzz-classification-model" }
- do:
ml.get_trained_models:
model_id: "a-*,lang*,zzz*"
from: 4
size: 100
- match: { count: 5 }
- length: { trained_model_configs: 1 }
- match: { trained_model_configs.0.model_id: "zzz-classification-model" }
---
"Test get models with tags":
- do: