From 24d41eb6958226dd1fc0e81a85b680ba926e3730 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 20 Apr 2020 16:08:54 -0400 Subject: [PATCH] [ML] partitions model definitions into chunks (#55260) (#55484) This paves the way in the data layer so that exceptionally large models can be partitioned across multiple documents. This change means that nodes before 7.8.0 will not be able to use trained inference models created on nodes on or after 7.8.0. I chose the definition document limit to be 100. This *SHOULD* be plenty for any large model. One of the largest models that I have created so far had the following stats: ~314MB of inflated JSON, ~66MB when compressed, ~177MB of heap. With the chunking size of `16 * 1024 * 1024` its compressed string could be partitioned into 5 documents. Supporting models 20 times this size (compressed) seems adequate for now. --- .../df-analytics/apis/put-inference.asciidoc | 8 +- .../TransportPutTrainedModelAction.java | 13 +- .../persistence/TrainedModelProvider.java | 118 +++++++++++++----- 3 files changed, 106 insertions(+), 33 deletions(-) diff --git a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc index 5521ccdde67..5928f9ddbdb 100644 --- a/docs/reference/ml/df-analytics/apis/put-inference.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-inference.asciidoc @@ -8,7 +8,13 @@ ++++ Creates an {infer} trained model. - ++ +-- +WARNING: Models created in version 7.8.0 are not backwards compatible + with older node versions. If in a mixed cluster environment, + all nodes must be at least 7.8.0 to use a model stored by + a 7.8.0 node. 
+-- experimental[] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index 503393a836a..ea48d2809e5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -78,7 +78,18 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction listener) { + protected void masterOperation(Request request, + ClusterState state, + ActionListener listener) { + // 7.8.0 introduced splitting the model definition across multiple documents. + // This means that new models will not be usable on nodes that cannot handle multiple definition documents + if (state.nodes().getMinNodeVersion().before(Version.V_7_8_0)) { + listener.onFailure(ExceptionsHelper.badRequestException( + "Creating a new model requires that all nodes are at least version [{}]", + request.getTrainedModelConfig().getModelId(), + Version.V_7_8_0.toString())); + return; + } try { request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry); request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 7820eff0c09..e4b87406dab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -15,7 +15,8 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import 
org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.bulk.BulkAction; -import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.MultiSearchAction; @@ -86,6 +87,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeSet; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -96,6 +98,9 @@ public class TrainedModelProvider { public static final Set MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1"); private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/"; private static final String MODEL_RESOURCE_FILE_EXT = ".json"; + private static final int COMPRESSED_STRING_CHUNK_SIZE = 16 * 1024 * 1024; + private static final int MAX_NUM_DEFINITION_DOCS = 100; + private static final int MAX_COMPRESSED_STRING_SIZE = COMPRESSED_STRING_CHUNK_SIZE * MAX_NUM_DEFINITION_DOCS; private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; @@ -139,30 +144,41 @@ public class TrainedModelProvider { private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig, ActionListener listener) { - TrainedModelDefinitionDoc trainedModelDefinitionDoc; + List trainedModelDefinitionDocs = new ArrayList<>(); try { - // TODO should we check length against allowed stream size??? 
String compressedString = trainedModelConfig.getCompressedDefinition(); - trainedModelDefinitionDoc = new TrainedModelDefinitionDoc.Builder() - .setDocNum(0) - .setModelId(trainedModelConfig.getModelId()) - .setCompressedString(compressedString) - .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) - .setDefinitionLength(compressedString.length()) - .setTotalDefinitionLength(compressedString.length()) - .build(); + if (compressedString.length() > MAX_COMPRESSED_STRING_SIZE) { + listener.onFailure( + ExceptionsHelper.badRequestException( + "Unable to store model as compressed definition has length [{}] the limit is [{}]", + compressedString.length(), + MAX_COMPRESSED_STRING_SIZE)); + return; + } + List chunkedStrings = chunkStringWithSize(compressedString, COMPRESSED_STRING_CHUNK_SIZE); + for(int i = 0; i < chunkedStrings.size(); ++i) { + trainedModelDefinitionDocs.add(new TrainedModelDefinitionDoc.Builder() + .setDocNum(i) + .setModelId(trainedModelConfig.getModelId()) + .setCompressedString(chunkedStrings.get(i)) + .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) + .setDefinitionLength(chunkedStrings.get(i).length()) + .setTotalDefinitionLength(compressedString.length()) + .build()); + } } catch (IOException ex) { listener.onFailure(ExceptionsHelper.serverError( - "Unexpected IOException while serializing definition for storage for model [" + trainedModelConfig.getModelId() + "]", - ex)); + "Unexpected IOException while serializing definition for storage for model [{}]", + ex, + trainedModelConfig.getModelId())); return; } - BulkRequest bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + BulkRequestBuilder bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .add(createRequest(trainedModelConfig.getModelId(), 
trainedModelConfig)) - .add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), 0), trainedModelDefinitionDoc)) - .request(); + .add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig)); + trainedModelDefinitionDocs.forEach(defDoc -> + bulkRequest.add(createRequest(TrainedModelDefinitionDoc.docId(trainedModelConfig.getModelId(), defDoc.getDocNum()), defDoc))); ActionListener wrappedListener = ActionListener.wrap( listener::onResponse, @@ -182,9 +198,8 @@ public class TrainedModelProvider { ActionListener bulkResponseActionListener = ActionListener.wrap( r -> { - assert r.getItems().length == 2; + assert r.getItems().length == trainedModelDefinitionDocs.size() + 1; if (r.getItems()[0].isFailed()) { - logger.error(new ParameterizedMessage( "[{}] failed to store trained model config for inference", trainedModelConfig.getModelId()), @@ -193,12 +208,18 @@ public class TrainedModelProvider { wrappedListener.onFailure(r.getItems()[0].getFailure().getCause()); return; } - if (r.getItems()[1].isFailed()) { + if (r.hasFailures()) { + Exception firstFailure = Arrays.stream(r.getItems()) + .filter(BulkItemResponse::isFailed) + .map(BulkItemResponse::getFailure) + .map(BulkItemResponse.Failure::getCause) + .findFirst() + .orElse(new Exception("unknown failure")); logger.error(new ParameterizedMessage( "[{}] failed to store trained model definition for inference", trainedModelConfig.getModelId()), - r.getItems()[1].getFailure().getCause()); - wrappedListener.onFailure(r.getItems()[1].getFailure().getCause()); + firstFailure); + wrappedListener.onFailure(firstFailure); return; } wrappedListener.onResponse(true); @@ -206,7 +227,7 @@ public class TrainedModelProvider { wrappedListener::onFailure ); - executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest, bulkResponseActionListener); + executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener); } public void 
getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener listener) { @@ -235,11 +256,20 @@ public class TrainedModelProvider { if (includeDefinition) { multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders - .idsQuery() - .addIds(TrainedModelDefinitionDoc.docId(modelId, 0)))) - // use sort to get the last + .boolQuery() + .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME)))) + // There should be AT MOST these many docs. There might be more if definitions have been reindex to newer indices + // If this ends up getting duplicate groups of definition documents, the parsing logic will throw away any doc that + // is in a different index than the first index seen. + .setSize(MAX_NUM_DEFINITION_DOCS) + // First find the latest index .addSort("_index", SortOrder.DESC) - .setSize(1) + // Then, sort by doc_num + .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) + .order(SortOrder.ASC) + // We need this for the search not to fail when there are no mappings yet in the index + .unmappedType("long")) .request()); } @@ -259,15 +289,18 @@ public class TrainedModelProvider { if (includeDefinition) { try { - TrainedModelDefinitionDoc doc = handleSearchItem(multiSearchResponse.getResponses()[1], + List docs = handleSearchItems(multiSearchResponse.getResponses()[1], modelId, this::parseModelDefinitionDocLenientlyFromSource); - if (doc.getCompressedString().length() != doc.getTotalDefinitionLength()) { + String compressedString = docs.stream() + .map(TrainedModelDefinitionDoc::getCompressedString) + .collect(Collectors.joining()); + if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { listener.onFailure(ExceptionsHelper.serverError( 
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); return; } - builder.setDefinitionFromString(doc.getCompressedString()); + builder.setDefinitionFromString(compressedString); } catch (ResourceNotFoundException ex) { listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); @@ -678,13 +711,36 @@ public class TrainedModelProvider { private static T handleSearchItem(MultiSearchResponse.Item item, String resourceId, CheckedBiFunction parseLeniently) throws Exception { + return handleSearchItems(item, resourceId, parseLeniently).get(0); + } + + // NOTE: This ignores any results that are in a different index than the first one seen in the search response. + private static List handleSearchItems(MultiSearchResponse.Item item, + String resourceId, + CheckedBiFunction parseLeniently) throws Exception { if (item.isFailure()) { throw item.getFailure(); } if (item.getResponse().getHits().getHits().length == 0) { throw new ResourceNotFoundException(resourceId); } - return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId); + List results = new ArrayList<>(item.getResponse().getHits().getHits().length); + String initialIndex = item.getResponse().getHits().getHits()[0].getIndex(); + for (SearchHit hit : item.getResponse().getHits().getHits()) { + // We don't want to spread across multiple backing indices + if (hit.getIndex().equals(initialIndex)) { + results.add(parseLeniently.apply(hit.getSourceRef(), resourceId)); + } + } + return results; + } + + static List chunkStringWithSize(String str, int chunkSize) { + List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); + for (int i = 0; i < str.length();i += chunkSize) { + subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + } + return subStrings; } private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws 
IOException {