This adds logic to handle paging problems when the ID pattern + tags reference models stored as resources. Most of the complexity comes from the fact that a model stored as a resource could land at the start or end of a page, or on the last page.
This commit is contained in:
parent dfc9f2330c
commit e372854d43
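Before the diff, a minimal standalone sketch of the buffering idea (illustrative only, not the commit's code; the class and variable names here are hypothetical, only the max/min formulas mirror the `expandIds` change below): the search window is widened so that models stored as resources cannot push matching documents off the requested page.

// Hypothetical illustration of the "buffered" window arithmetic used in expandIds below.
public class WindowSketch {
    private static final int MAX_WINDOW = 10_000; // same cap as in the commit

    public static void main(String[] args) {
        int numResourceModels = 2; // ids served from classpath resources, absent from the index
        int[][] requestedPages = { {0, 3}, {1, 2}, {4, 2} }; // {from, size}

        for (int[] page : requestedPages) {
            int from = page[0], size = page[1];
            // Widen the window: a resource model may occupy any slot on the page.
            int bufferedFrom = Math.max(0, from - numResourceModels);
            int bufferedSize = Math.min(MAX_WINDOW, size + numResourceModels);
            System.out.printf("requested (from=%d, size=%d) -> query (from=%d, size=%d)%n",
                from, size, bufferedFrom, bufferedSize);
        }
    }
}

The over-fetched window is then trimmed back to the requested page after merging in the resource-backed ids, which is what the new collectIds method below does.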
@@ -28,7 +28,6 @@ import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.CheckedBiFunction;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.collect.Tuple;
@@ -74,10 +73,10 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
-import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.TreeSet;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -382,19 +381,34 @@ public class TrainedModelProvider {
 
     public void expandIds(String idExpression,
                           boolean allowNoResources,
-                          @Nullable PageParams pageParams,
+                          PageParams pageParams,
                           Set<String> tags,
                           ActionListener<Tuple<Long, Set<String>>> idsListener) {
         String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
+        Set<String> matchedResourceIds = matchedResourceIds(tokens);
+        Set<String> foundResourceIds;
+        if (tags.isEmpty()) {
+            foundResourceIds = matchedResourceIds;
+        } else {
+            foundResourceIds = new HashSet<>();
+            for(String resourceId : matchedResourceIds) {
+                // Does the model as a resource have all the tags?
+                if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
+                    foundResourceIds.add(resourceId);
+                }
+            }
+        }
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
             .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
                 // If there are no resources, there might be no mapping for the id field.
                 // This makes sure we don't get an error if that happens.
                 .unmappedType("long"))
-            .query(buildExpandIdsQuery(tokens, tags));
-        if (pageParams != null) {
-            sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
-        }
+            .query(buildExpandIdsQuery(tokens, tags))
+            // We "buffer" the from and size to take into account models stored as resources.
+            // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of
+            // a page.
+            .from(Math.max(0, pageParams.getFrom() - foundResourceIds.size()))
+            .size(Math.min(10_000, pageParams.getSize() + foundResourceIds.size()));
         sourceBuilder.trackTotalHits(true)
             // we only care about the item id's
             .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null);
@@ -407,17 +421,6 @@ public class TrainedModelProvider {
                 indicesOptions.expandWildcardsClosed(),
                 indicesOptions))
             .source(sourceBuilder);
-        Set<String> foundResourceIds = new LinkedHashSet<>();
-        if (tags.isEmpty()) {
-            foundResourceIds.addAll(matchedResourceIds(tokens));
-        } else {
-            for(String resourceId : matchedResourceIds(tokens)) {
-                // Does the model as a resource have all the tags?
-                if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
-                    foundResourceIds.add(resourceId);
-                }
-            }
-        }
 
         executeAsyncWithOrigin(client.threadPool().getThreadContext(),
             ML_ORIGIN,
@@ -425,6 +428,7 @@ public class TrainedModelProvider {
             ActionListener.<SearchResponse>wrap(
                 response -> {
                     long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
+                    Set<String> foundFromDocs = new HashSet<>();
                     for (SearchHit hit : response.getHits().getHits()) {
                         Map<String, Object> docSource = hit.getSourceAsMap();
                         if (docSource == null) {
@@ -432,15 +436,17 @@ public class TrainedModelProvider {
                         }
                         Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName());
                         if (idValue instanceof String) {
-                            foundResourceIds.add(idValue.toString());
+                            foundFromDocs.add(idValue.toString());
                         }
                     }
+                    Set<String> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs);
                     ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
-                    requiredMatches.filterMatchedIds(foundResourceIds);
+                    requiredMatches.filterMatchedIds(allFoundIds);
                     if (requiredMatches.hasUnmatchedIds()) {
                         idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
                     } else {
-                        idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds));
+                        idsListener.onResponse(Tuple.tuple(totalHitCount, allFoundIds));
                     }
                 },
                 idsListener::onFailure
@@ -448,6 +454,32 @@ public class TrainedModelProvider {
             client::search);
     }
 
+    static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) {
+        // If there are no matching resource models, there was no buffering and the models from the docs
+        // are paginated correctly.
+        if (foundFromResources.isEmpty()) {
+            return foundFromDocs;
+        }
+
+        TreeSet<String> allFoundIds = new TreeSet<>(foundFromDocs);
+        allFoundIds.addAll(foundFromResources);
+
+        if (pageParams.getFrom() > 0) {
+            // not the first page so there will be extra results at the front to remove
+            int numToTrimFromFront = Math.min(foundFromResources.size(), pageParams.getFrom());
+            for (int i = 0; i < numToTrimFromFront; i++) {
+                allFoundIds.remove(allFoundIds.first());
+            }
+        }
+
+        // trim down to size removing from the rear
+        while (allFoundIds.size() > pageParams.getSize()) {
+            allFoundIds.remove(allFoundIds.last());
+        }
+
+        return allFoundIds;
+    }
+
     static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
         BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
             .filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
@@ -518,7 +550,7 @@ public class TrainedModelProvider {
 
     private Set<String> matchedResourceIds(String[] tokens) {
         if (Strings.isAllOrWildcard(tokens)) {
-            return new HashSet<>(MODELS_STORED_AS_RESOURCE);
+            return MODELS_STORED_AS_RESOURCE;
         }
 
         Set<String> matchedModels = new HashSet<>();
@@ -536,7 +568,7 @@ public class TrainedModelProvider {
                 }
             }
         }
-        return matchedModels;
+        return Collections.unmodifiableSet(matchedModels);
     }
 
     private static <T> T handleSearchItem(MultiSearchResponse.Item item,
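To make the trim in collectIds concrete, here is a standalone re-statement (hypothetical code, not from the commit) run on the PageParams(1, 2) vector from the unit tests that follow: resource models {a, b}, doc hits {c, d, e}.

import java.util.Arrays;
import java.util.TreeSet;

// Re-statement of the collectIds trim logic on one unit-test vector.
public class TrimSketch {
    public static void main(String[] args) {
        int from = 1, size = 2;
        int numFromResources = 2; // resource models "a" and "b"

        TreeSet<String> ids = new TreeSet<>(Arrays.asList("c", "d", "e")); // doc hits
        ids.addAll(Arrays.asList("a", "b"));                               // resource models

        // The buffered query may have pulled in up to numFromResources ids before the page.
        for (int i = 0; i < Math.min(numFromResources, from); i++) {
            ids.remove(ids.first()); // drops "a"
        }
        // Trim the rear back down to the requested page size.
        while (ids.size() > size) {
            ids.remove(ids.last()); // drops "e", then "d"
        }
        System.out.println(ids); // [b, c] -- the expected page
    }
}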
@@ -14,12 +14,16 @@ import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.action.util.PageParams;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
+import java.util.TreeSet;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
@@ -86,6 +90,50 @@ public class TrainedModelProviderTests extends ESTestCase {
         });
     }
 
+    public void testExpandIdsPagination() {
+        // NOTE: these tests assume that the query pagination results are "buffered"
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
+            Collections.emptySet(),
+            new HashSet<>(Arrays.asList("a", "b", "c"))),
+            equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
+            Collections.singleton("a"),
+            new HashSet<>(Arrays.asList("b", "c", "d"))),
+            equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
+            Collections.singleton("a"),
+            new HashSet<>(Arrays.asList("b", "c", "d"))),
+            equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
+            Collections.singleton("c"),
+            new HashSet<>(Arrays.asList("a", "b"))),
+            equalTo(new TreeSet<>(Arrays.asList("b"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
+            Collections.singleton("b"),
+            new HashSet<>(Arrays.asList("a", "c"))),
+            equalTo(new TreeSet<>(Arrays.asList("b"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 2),
+            new HashSet<>(Arrays.asList("a", "b")),
+            new HashSet<>(Arrays.asList("c", "d", "e"))),
+            equalTo(new TreeSet<>(Arrays.asList("b", "c"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
+            new HashSet<>(Arrays.asList("a", "b")),
+            new HashSet<>(Arrays.asList("c", "d", "e"))),
+            equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(2, 3),
+            new HashSet<>(Arrays.asList("a", "b")),
+            new HashSet<>(Arrays.asList("c", "d", "e"))),
+            equalTo(new TreeSet<>(Arrays.asList("c", "d", "e"))));
+    }
+
     public void testGetModelThatExistsAsResourceButIsMissing() {
         TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
         ElasticsearchException ex = expectThrows(ElasticsearchException.class,
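The REST tests below exercise the same behaviour end to end. The six models sort as a-classification-model, a-regression-model-0, a-regression-model-1, lang_ident_model_1 (the model stored as a resource here), yyy-classification-model, zzz-classification-model, so each expected page is simply a slice of that combined list. A small illustration (hypothetical code, mirroring the "from: 3, size: 2" test case):

import java.util.List;

// What the "from: 3, size: 2" REST test asserts: a plain slice over the
// combined, sorted id list (index-backed models plus the resource model).
public class PageSliceSketch {
    public static void main(String[] args) {
        List<String> all = List.of(
            "a-classification-model", "a-regression-model-0", "a-regression-model-1",
            "lang_ident_model_1", "yyy-classification-model", "zzz-classification-model");
        int from = 3, size = 2;
        System.out.println(all.subList(from, Math.min(all.size(), from + size)));
        // -> [lang_ident_model_1, yyy-classification-model]
    }
}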
@@ -72,6 +72,56 @@ setup:
         }
       }
     }
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: yyy-classification-model
+        body: >
+          {
+            "description": "empty model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "tags": ["classification", "tag3"],
+            "definition": {
+              "preprocessors": [],
+              "trained_model": {
+                "tree": {
+                  "feature_names": ["field1", "field2"],
+                  "tree_structure": [
+                    {"node_index": 0, "leaf_value": 1}
+                  ],
+                  "target_type": "classification",
+                  "classification_labels": ["no", "yes"]
+                }
+              }
+            }
+          }
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: zzz-classification-model
+        body: >
+          {
+            "description": "empty model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "tags": ["classification", "tag3"],
+            "definition": {
+              "preprocessors": [],
+              "trained_model": {
+                "tree": {
+                  "feature_names": ["field1", "field2"],
+                  "tree_structure": [
+                    {"node_index": 0, "leaf_value": 1}
+                  ],
+                  "target_type": "classification",
+                  "classification_labels": ["no", "yes"]
+                }
+              }
+            }
+          }
 ---
 "Test get given missing trained model":
 
@@ -102,15 +152,20 @@ setup:
   - do:
       ml.get_trained_models:
         model_id: "*"
-  - match: { count: 4 }
+  - match: { count: 6 }
+  - length: { trained_model_configs: 6 }
   - match: { trained_model_configs.0.model_id: "a-classification-model" }
   - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
   - match: { trained_model_configs.2.model_id: "a-regression-model-1" }
   - match: { trained_model_configs.3.model_id: "lang_ident_model_1" }
+  - match: { trained_model_configs.4.model_id: "yyy-classification-model" }
+  - match: { trained_model_configs.5.model_id: "zzz-classification-model" }
 
   - do:
       ml.get_trained_models:
         model_id: "a-regression*"
   - match: { count: 2 }
+  - length: { trained_model_configs: 2 }
   - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
   - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
 
@@ -119,7 +174,8 @@ setup:
         model_id: "*"
         from: 0
         size: 2
-  - match: { count: 4 }
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
   - match: { trained_model_configs.0.model_id: "a-classification-model" }
   - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
 
@@ -128,8 +184,78 @@ setup:
         model_id: "*"
         from: 1
         size: 1
-  - match: { count: 4 }
+  - match: { count: 6 }
+  - length: { trained_model_configs: 1 }
   - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
 
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 2
+        size: 2
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-1" }
+  - match: { trained_model_configs.1.model_id: "lang_ident_model_1" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 3
+        size: 1
+  - match: { count: 6 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 3
+        size: 2
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
+  - match: { trained_model_configs.1.model_id: "yyy-classification-model" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 4
+        size: 2
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "yyy-classification-model" }
+  - match: { trained_model_configs.1.model_id: "zzz-classification-model" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "a-*,lang*,zzz*"
+        allow_no_match: true
+        from: 3
+        size: 1
+  - match: { count: 5 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "a-*,lang*,zzz*"
+        allow_no_match: true
+        from: 4
+        size: 1
+  - match: { count: 5 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "zzz-classification-model" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "a-*,lang*,zzz*"
+        from: 4
+        size: 100
+  - match: { count: 5 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "zzz-classification-model" }
+
 ---
 "Test get models with tags":
   - do: