[ML] partitions model definitions into chunks (#55260) (#55484)

This paves the way in the data layer for exceptionally large models to be partitioned across multiple documents.

This change means that nodes on versions before 7.8.0 will not be able to use trained inference models created on nodes running 7.8.0 or later.

I chose the definition document limit to be 100. This *SHOULD* be plenty for any large model. One of the largest models that I have created so far had the following stats:
~314MB of inflated JSON, ~66MB when compressed, ~177MB of heap.
With a chunk size of `16 * 1024 * 1024`, its compressed string is partitioned into 5 documents.
Supporting models 20 times this size (compressed) seems adequate for now.
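As a sanity check on those numbers, here is a quick back-of-the-envelope calculation in plain Java (not part of the change; the figures come straight from the commit message and the two constants introduced in the diff below). Note the chunking operates on the compressed definition string's length in chars; assuming the compressed form is an ASCII (base64-style) string, char counts and byte counts coincide.

```java
public class ChunkMath {
    public static void main(String[] args) {
        final int chunkSize = 16 * 1024 * 1024;            // COMPRESSED_STRING_CHUNK_SIZE: 16 MiB
        final int maxDocs = 100;                           // MAX_NUM_DEFINITION_DOCS
        final long exampleCompressed = 66L * 1024 * 1024;  // the ~66MB example model above

        // Documents needed for the example model: ceiling division by the chunk size.
        long docsNeeded = (exampleCompressed + chunkSize - 1) / chunkSize;
        System.out.println("docs needed: " + docsNeeded);  // 5, as stated above

        // Hard cap implied by the two constants: 100 chunks of 16 MiB each.
        long maxCompressed = (long) chunkSize * maxDocs;
        System.out.println("max compressed: " + maxCompressed); // 1677721600 (~1.6 GiB, ~24x the example)
    }
}
```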
Committed by Benjamin Trent on 2020-04-20 16:08:54 -04:00 (via GitHub)
parent fa0373a19f
commit 24d41eb695
3 changed files with 106 additions and 33 deletions

File 1 of 3: create trained model API docs (asciidoc)

@@ -8,7 +8,13 @@
 ++++
 Creates an {infer} trained model.
+
+--
+WARNING: Models created in version 7.8.0 are not backwards compatible
+with older node versions. If in a mixed cluster environment,
+all nodes must be at least 7.8.0 to use a model stored by
+a 7.8.0 node.
+--
 experimental[]

File 2 of 3: TransportPutTrainedModelAction.java

@@ -78,7 +78,18 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
     }
     @Override
-    protected void masterOperation(Request request, ClusterState state, ActionListener<Response> listener) {
+    protected void masterOperation(Request request,
+                                   ClusterState state,
+                                   ActionListener<Response> listener) {
+        // 7.8.0 introduced splitting the model definition across multiple documents.
+        // This means that new models will not be usable on nodes that cannot handle multiple definition documents.
+        if (state.nodes().getMinNodeVersion().before(Version.V_7_8_0)) {
+            listener.onFailure(ExceptionsHelper.badRequestException(
+                "Creating a new model [{}] requires that all nodes are at least version [{}]",
+                request.getTrainedModelConfig().getModelId(),
+                Version.V_7_8_0.toString()));
+            return;
+        }
         try {
             request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
             request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
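The gate above keys off the minimum node version in the cluster state, before any parsing or validation happens, so a single pre-7.8.0 node blocks model creation cluster-wide. A minimal standalone sketch of that behavior, using a hypothetical `Version` record and check method as stand-ins rather than the real Elasticsearch classes:

```java
import java.util.List;

public class VersionGateSketch {
    // Hypothetical stand-in for org.elasticsearch.Version; only ordering matters here.
    record Version(int major, int minor, int patch) implements Comparable<Version> {
        static final Version V_7_8_0 = new Version(7, 8, 0);

        boolean before(Version other) { return compareTo(other) < 0; }

        @Override
        public int compareTo(Version o) {
            int c = Integer.compare(major, o.major);
            if (c != 0) return c;
            c = Integer.compare(minor, o.minor);
            return c != 0 ? c : Integer.compare(patch, o.patch);
        }

        @Override
        public String toString() { return major + "." + minor + "." + patch; }
    }

    // Mirrors the gate in masterOperation: reject if ANY node is older than 7.8.0,
    // because the minimum node version is what decides.
    static void checkCanStoreChunkedModel(List<Version> nodeVersions) {
        Version min = nodeVersions.stream().min(Version::compareTo).orElseThrow();
        if (min.before(Version.V_7_8_0)) {
            throw new IllegalArgumentException(
                "Creating a new model requires that all nodes are at least version [" + Version.V_7_8_0 + "]");
        }
    }

    public static void main(String[] args) {
        checkCanStoreChunkedModel(List.of(new Version(7, 8, 0), new Version(7, 9, 1))); // passes
        try {
            checkCanStoreChunkedModel(List.of(new Version(7, 7, 1), new Version(7, 8, 0)));
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage()); // one old node blocks creation
        }
    }
}
```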

File 3 of 3: TrainedModelProvider.java

@@ -15,7 +15,8 @@ import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.DocWriteRequest;
 import org.elasticsearch.action.bulk.BulkAction;
-import org.elasticsearch.action.bulk.BulkRequest;
+import org.elasticsearch.action.bulk.BulkItemResponse;
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.MultiSearchAction;
@@ -86,6 +87,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeSet;
+import java.util.stream.Collectors;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -96,6 +98,9 @@ public class TrainedModelProvider {
     public static final Set<String> MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1");
     private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/";
     private static final String MODEL_RESOURCE_FILE_EXT = ".json";
+    private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024;
+    private static final int MAX_NUM_DEFINITION_DOCS = 100;
+    private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS;
     private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
     private final Client client;
@@ -139,30 +144,41 @@ public class TrainedModelProvider {
     private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig,
                                                 ActionListener<Boolean> listener) {
-        TrainedModelDefinitionDoc trainedModelDefinitionDoc;
+        List<TrainedModelDefinitionDoc> trainedModelDefinitionDocs = new ArrayList<>();
         try {
-            // TODO should we check length against allowed stream size???
             String compressedString = trainedModelConfig.getCompressedDefinition();
-            trainedModelDefinitionDoc = new TrainedModelDefinitionDoc.Builder()
-                .setDocNum(0)
-                .setModelId(trainedModelConfig.getModelId())
-                .setCompressedString(compressedString)
-                .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
-                .setDefinitionLength(compressedString.length())
-                .setTotalDefinitionLength(compressedString.length())
-                .build();
+            if (compressedString.length() > MAX_COMPRESSED_STRING_SIZE) {
+                listener.onFailure(
+                    ExceptionsHelper.badRequestException(
+                        "Unable to store model as compressed definition has length [{}]; the limit is [{}]",
+                        compressedString.length(),
+                        MAX_COMPRESSED_STRING_SIZE));
+                return;
+            }
+            List<String> chunkedStrings = chunkStringWithSize(compressedString, COMPRESSED_STRING_CHUNK_SIZE);
+            for (int i = 0; i < chunkedStrings.size(); ++i) {
+                trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder()
+                    .setDocNum(i)
+                    .setModelId(trainedModelConfig.getModelId())
+                    .setCompressedString(chunkedStrings.get(i))
+                    .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
+                    .setDefinitionLength(chunkedStrings.get(i).length())
+                    .setTotalDefinitionLength(compressedString.length())
+                    .build());
+            }
         } catch (IOException ex) {
             listener.onFailure(ExceptionsHelper.serverError(
-                "Unexpected IOException while serializing definition for storage for model [" + trainedModelConfig.getModelId() + "]",
-                ex));
+                "Unexpected IOException while serializing definition for storage for model [{}]",
+                ex,
+                trainedModelConfig.getModelId()));
             return;
         }
-        BulkRequest bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
+        BulkRequestBuilder bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig))
-            .add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), 0), trainedModelDefinitionDoc))
-            .request();
+            .add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig));
+        trainedModelDefinitionDocs.forEach(defDoc ->
+            bulkRequest.add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), defDoc.getDocNum()), defDoc)));
         ActionListener<Boolean> wrappedListener = ActionListener.wrap(
             listener::onResponse,
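The storage path above boils down to: split the compressed string, wrap each chunk in a numbered doc that also records the chunk length and the total length, then index the config plus all chunks in one bulk request with an immediate refresh. A simplified standalone sketch of the doc-building step; the `DefinitionDoc` record and its doc id format are illustrative stand-ins, not the real `TrainedModelDefinitionDoc`:

```java
import java.util.ArrayList;
import java.util.List;

public class ChunkedStoreSketch {
    // Hypothetical, simplified stand-in for TrainedModelDefinitionDoc.
    record DefinitionDoc(String modelId, int docNum, String compressedChunk,
                         int definitionLength, int totalDefinitionLength) {
        // Illustrative doc id only; the real format comes from TrainedModelDefinitionDoc.docId(...).
        String docId() { return modelId + "_definition_doc_" + docNum; }
    }

    // Mirrors the loop in storeTrainedModelAndDefinition: one numbered doc per chunk,
    // each carrying its own length plus the total compressed length.
    static List<DefinitionDoc> toDefinitionDocs(String modelId, String compressed, int chunkSize) {
        List<DefinitionDoc> docs = new ArrayList<>();
        for (int i = 0, docNum = 0; i < compressed.length(); i += chunkSize, docNum++) {
            String chunk = compressed.substring(i, Math.min(i + chunkSize, compressed.length()));
            docs.add(new DefinitionDoc(modelId, docNum, chunk, chunk.length(), compressed.length()));
        }
        return docs;
    }

    public static void main(String[] args) {
        // A 10-char "compressed" string with 4-char chunks yields 3 docs of lengths 4, 4, 2.
        toDefinitionDocs("my-model", "aaaabbbbcc", 4)
            .forEach(d -> System.out.println(d.docId() + " len=" + d.definitionLength()));
    }
}
```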
@@ -182,9 +198,8 @@ public class TrainedModelProvider {
         ActionListener<BulkResponse> bulkResponseActionListener = ActionListener.wrap(
             r -> {
-                assert r.getItems().length == 2;
+                assert r.getItems().length == trainedModelDefinitionDocs.size() + 1;
                 if (r.getItems()[0].isFailed()) {
                     logger.error(new ParameterizedMessage(
                         "[{}] failed to store trained model config for inference",
                         trainedModelConfig.getModelId()),
@@ -193,12 +208,18 @@ public class TrainedModelProvider {
                     wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
                     return;
                 }
-                if (r.getItems()[1].isFailed()) {
+                if (r.hasFailures()) {
+                    Exception firstFailure = Arrays.stream(r.getItems())
+                        .filter(BulkItemResponse::isFailed)
+                        .map(BulkItemResponse::getFailure)
+                        .map(BulkItemResponse.Failure::getCause)
+                        .findFirst()
+                        .orElse(new Exception("unknown failure"));
                     logger.error(new ParameterizedMessage(
                         "[{}] failed to store trained model definition for inference",
                         trainedModelConfig.getModelId()),
-                        r.getItems()[1].getFailure().getCause());
-                    wrappedListener.onFailure(r.getItems()[1].getFailure().getCause());
+                        firstFailure);
+                    wrappedListener.onFailure(firstFailure);
                     return;
                 }
                 wrappedListener.onResponse(true);
@@ -206,7 +227,7 @@ public class TrainedModelProvider {
             wrappedListener::onFailure
         );
-        executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest, bulkResponseActionListener);
+        executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
     }
     public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
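The failure handling treats item 0 (the config doc) specially, then surfaces the first failure among the remaining items via `r.hasFailures()` and a stream over `r.getItems()`. A standalone sketch of that selection, with a hypothetical `Item` record standing in for `BulkItemResponse`:

```java
import java.util.Arrays;
import java.util.Optional;

public class FirstBulkFailureSketch {
    // Hypothetical stand-in for BulkItemResponse: a failed item carries a cause.
    record Item(Exception failureCause) {
        boolean isFailed() { return failureCause != null; }
    }

    // Mirrors the stream in the bulk response listener: find the first failed item's
    // cause; the real code falls back to new Exception("unknown failure") if none is set.
    static Optional<Exception> firstFailure(Item[] items) {
        return Arrays.stream(items)
            .filter(Item::isFailed)
            .map(Item::failureCause)
            .findFirst();
    }

    public static void main(String[] args) {
        Item[] items = {
            new Item(null),                             // item 0: config doc stored fine
            new Item(null),                             // definition doc 0 stored fine
            new Item(new RuntimeException("disk full")) // definition doc 1 failed
        };
        firstFailure(items).ifPresent(e -> System.out.println("surfaced: " + e.getMessage()));
    }
}
```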
@@ -235,11 +256,20 @@ public class TrainedModelProvider {
         if (includeDefinition) {
             multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
                 .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
-                    .idsQuery()
-                    .addIds(TrainedModelDefinitionDoc.docId(modelId, 0))))
-                // use sort to get the last
+                    .boolQuery()
+                    .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
+                    .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME))))
+                // There should be AT MOST this many docs. There might be more if definitions have been reindexed to newer indices.
+                // If this ends up getting duplicate groups of definition documents, the parsing logic will throw away any doc that
+                // is in a different index than the first index seen.
+                .setSize(MAX_NUM_DEFINITION_DOCS)
+                // First find the latest index
                 .addSort("_index", SortOrder.DESC)
-                .setSize(1)
+                // Then, sort by doc_num
+                .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName())
+                    .order(SortOrder.ASC)
+                    // We need this for the search not to fail when there are no mappings yet in the index
+                    .unmappedType("long"))
                 .request());
         }
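Retrieval relies on sort order: `_index` descending puts the newest backing index's docs first, and `doc_num` ascending keeps each group in chunk order; the parsing logic then keeps only docs from the first index seen. A small illustration of that ordering with a hypothetical `Hit` record and made-up index names:

```java
import java.util.Comparator;
import java.util.List;

public class DefinitionSortSketch {
    record Hit(String index, int docNum) {}

    public static void main(String[] args) {
        // Made-up index names; any lexicographically increasing rollover naming behaves the same.
        List<Hit> hits = List.of(
            new Hit(".ml-inference-000001", 0), new Hit(".ml-inference-000001", 1),
            new Hit(".ml-inference-000002", 0), new Hit(".ml-inference-000002", 1));

        // The same ordering the search asks for: _index DESC, then doc_num ASC.
        hits.stream()
            .sorted(Comparator.comparing(Hit::index).reversed().thenComparingInt(Hit::docNum))
            .forEach(System.out::println);
        // The ...-000002 docs come first, in chunk order; handleSearchItems then
        // drops everything not in that first (newest) index.
    }
}
```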
@@ -259,15 +289,18 @@ public class TrainedModelProvider {
             if (includeDefinition) {
                 try {
-                    TrainedModelDefinitionDoc doc = handleSearchItem(multiSearchResponse.getResponses()[1],
+                    List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
                         modelId,
                         this::parseModelDefinitionDocLenientlyFromSource);
-                    if (doc.getCompressedString().length() != doc.getTotalDefinitionLength()) {
+                    String compressedString = docs.stream()
+                        .map(TrainedModelDefinitionDoc::getCompressedString)
+                        .collect(Collectors.joining());
+                    if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
                         listener.onFailure(ExceptionsHelper.serverError(
                             Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
                         return;
                     }
-                    builder.setDefinitionFromString(doc.getCompressedString());
+                    builder.setDefinitionFromString(compressedString);
                 } catch (ResourceNotFoundException ex) {
                     listener.onFailure(new ResourceNotFoundException(
                         Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
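Reassembly is a plain join of the chunks in `doc_num` order, with the `total_definition_length` stored on each doc serving as a truncation check: a lost middle chunk shows up as a length mismatch. A standalone sketch under those assumptions, reusing a stripped-down hypothetical `DefinitionDoc`:

```java
import java.util.List;
import java.util.stream.Collectors;

public class ReassemblySketch {
    // Only the fields the reassembly needs.
    record DefinitionDoc(int docNum, String compressedChunk, int totalDefinitionLength) {}

    static String reassemble(List<DefinitionDoc> docs) {
        // Docs arrive already sorted by doc_num (the search guarantees the order).
        String compressed = docs.stream()
            .map(DefinitionDoc::compressedChunk)
            .collect(Collectors.joining());
        if (compressed.length() != docs.get(0).totalDefinitionLength()) {
            // Mirrors the MODEL_DEFINITION_TRUNCATED server error.
            throw new IllegalStateException("model definition truncated");
        }
        return compressed;
    }

    public static void main(String[] args) {
        List<DefinitionDoc> complete = List.of(
            new DefinitionDoc(0, "aaaa", 6), new DefinitionDoc(1, "bb", 6));
        System.out.println(reassemble(complete)); // aaaabb

        try {
            reassemble(List.of(new DefinitionDoc(0, "aaaa", 6))); // chunk 1 missing
        } catch (IllegalStateException e) {
            System.out.println("caught: " + e.getMessage());
        }
    }
}
```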
@@ -678,13 +711,36 @@ public class TrainedModelProvider {
     private static <T> T handleSearchItem(MultiSearchResponse.Item item,
                                           String resourceId,
                                           CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
+        return handleSearchItems(item, resourceId, parseLeniently).get(0);
+    }
+
+    // NOTE: This ignores any results that are in a different index than the first one seen in the search response.
+    private static <T> List<T> handleSearchItems(MultiSearchResponse.Item item,
+                                                 String resourceId,
+                                                 CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
         if (item.isFailure()) {
             throw item.getFailure();
         }
         if (item.getResponse().getHits().getHits().length == 0) {
             throw new ResourceNotFoundException(resourceId);
         }
-        return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId);
+        List<T> results = new ArrayList<>(item.getResponse().getHits().getHits().length);
+        String initialIndex = item.getResponse().getHits().getHits()[0].getIndex();
+        for (SearchHit hit : item.getResponse().getHits().getHits()) {
+            // We don't want to spread across multiple backing indices
+            if (hit.getIndex().equals(initialIndex)) {
+                results.add(parseLeniently.apply(hit.getSourceRef(), resourceId));
+            }
+        }
+        return results;
+    }
+
+    static List<String> chunkStringWithSize(String str, int chunkSize) {
+        List<String> subStrings = new ArrayList<>((int) Math.ceil(str.length() / (double) chunkSize));
+        for (int i = 0; i < str.length(); i += chunkSize) {
+            subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
+        }
+        return subStrings;
+    }
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException {
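Finally, `chunkStringWithSize` is small enough to exercise directly; a quick usage example covering the edge cases (exact multiple of the chunk size, trailing partial chunk, empty input), with the method body copied from the diff above:

```java
import java.util.ArrayList;
import java.util.List;

public class ChunkStringExample {
    // Copied from the diff above: substring-based chunking over char positions.
    static List<String> chunkStringWithSize(String str, int chunkSize) {
        List<String> subStrings = new ArrayList<>((int) Math.ceil(str.length() / (double) chunkSize));
        for (int i = 0; i < str.length(); i += chunkSize) {
            subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
        }
        return subStrings;
    }

    public static void main(String[] args) {
        System.out.println(chunkStringWithSize("abcdef", 3));  // [abc, def]     exact multiple
        System.out.println(chunkStringWithSize("abcdefg", 3)); // [abc, def, g]  trailing partial chunk
        System.out.println(chunkStringWithSize("", 3));        // []             empty input yields no chunks
    }
}
```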