From c64e283dbf28cbdc79f1b9e53e7ff42bace546ae Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 1 Jul 2020 15:14:31 -0400 Subject: [PATCH] [7.x] [ML] handles compressed model stream from native process (#58009) (#58836) * [ML] handles compressed model stream from native process (#58009) This moves model storage from handling the fully parsed JSON string to handling two separate types of documents. 1. ModelSizeInfo which contains model size information 2. TrainedModelDefinitionChunk which contains a particular chunk of the compressed model definition string. `model_size_info` is assumed to be handled first. This will generate the model_id and store the initial trained model config object. Then each chunk is assumed to be in correct order for concatenating the chunks to get a compressed definition. Native side change: https://github.com/elastic/ml-cpp/pull/1349 --- .../xpack/core/ml/job/messages/Messages.java | 1 + .../ml/integration/ClassificationIT.java | 2 - .../xpack/ml/integration/RegressionIT.java | 2 - .../ChunkedTrainedModelPersisterIT.java | 130 ++++++++++ .../integration/TrainedModelProviderIT.java | 52 +++- .../process/AnalyticsResultProcessor.java | 121 ++------- .../process/ChunkedTrainedModelPersister.java | 235 ++++++++++++++++++ .../process/results/AnalyticsResult.java | 53 ++-- .../results/TrainedModelDefinitionChunk.java | 89 +++++++ .../TrainedModelDefinitionDoc.java | 40 ++- .../persistence/TrainedModelProvider.java | 146 +++++++++-- .../AnalyticsResultProcessorTests.java | 103 -------- .../ChunkedTrainedModelPersisterTests.java | 150 +++++++++++ .../process/results/AnalyticsResultTests.java | 15 +- 14 files changed, 853 insertions(+), 286 deletions(-) create mode 100644 x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java create mode 100644 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 868dce9042c..6bcf179578f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -89,6 +89,7 @@ public final class Messages { " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric"; public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; + public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 61b607825e8..866b7d0c1cf 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.integration; import 
org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionModule; @@ -67,7 +66,6 @@ import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.startsWith; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349") public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String BOOLEAN_FIELD = "boolean-field"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index bbd3478e41e..852aee7f97d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.ml.integration; -import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionModule; import org.elasticsearch.action.DocWriteRequest; @@ -45,7 +44,6 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349") public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String NUMERICAL_FEATURE_FIELD = "feature"; diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java 
b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java new file mode 100644 index 00000000000..74581ac3d45 --- /dev/null +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.license.License; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; +import 
org.elasticsearch.xpack.ml.extractor.DocValueField; +import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; + +public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { + + private TrainedModelProvider trainedModelProvider; + + @Before + public void createComponents() throws Exception { + trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + waitForMlTemplates(); + } + + public void testStoreModelViaChunkedPersister() throws IOException { + String modelId = "stored-chunked-model"; + DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder() + .setId(modelId) + .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) + .setDest(new DataFrameAnalyticsDest("my_dest", null)) + .setAnalysis(new Regression("foo")) + .build(); + List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); + TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId); + String compressedDefinition = configBuilder.build().getCompressedDefinition(); + int totalSize = compressedDefinition.length(); + List chunks = chunkStringWithSize(compressedDefinition, totalSize/3); + + ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider, + 
analyticsConfig, + new DataFrameAnalyticsAuditor(client(), "test-node"), + (ex) -> { throw new ElasticsearchException(ex); }, + new ExtractedFields(extractedFieldList, Collections.emptyMap()) + ); + + //Accuracy for size is not tested here + ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); + persister.createAndIndexInferenceModelMetadata(modelSizeInfo); + for (int i = 0; i < chunks.size(); i++) { + persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1))); + } + + PlainActionFuture>> getIdsFuture = new PlainActionFuture<>(); + trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture); + Tuple> ids = getIdsFuture.actionGet(); + assertThat(ids.v1(), equalTo(1L)); + + PlainActionFuture getTrainedModelFuture = new PlainActionFuture<>(); + trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture); + + TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet(); + assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition)); + assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations())); + assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { + TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder(); + long bytesUsed = definitionBuilder.build().ramBytesUsed(); + long operations = definitionBuilder.build().getTrainedModel().estimatedNumOperations(); + return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION)) + .setDescription("trained model config for test") + .setModelId(modelId) + .setVersion(Version.CURRENT) + 
.setLicenseLevel(License.OperationMode.PLATINUM.description()) + .setEstimatedHeapMemory(bytesUsed) + .setEstimatedOperations(operations) + .setInput(TrainedModelInputTests.createRandomInput()); + } + + public static List chunkStringWithSize(String str, int chunkSize) { + List subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize); + for (int i = 0; i < str.length(); i += chunkSize) { + subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length()))); + } + return subStrings; + } + + @Override + public NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } + +} diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index b518e01f99d..b920c5686dd 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -32,8 +32,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; +import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; @@ -157,8 +160,8 @@ public class TrainedModelProviderIT extends 
MlSingleNodeTestCase { equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); } - public void testGetTruncatedModelDefinition() throws Exception { - String modelId = "test-get-truncated-model-config"; + public void testGetTruncatedModelDeprecatedDefinition() throws Exception { + String modelId = "test-get-truncated-legacy-model-config"; TrainedModelConfig config = buildTrainedModelConfig(modelId); AtomicReference putConfigHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); @@ -196,6 +199,51 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase { assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); } + public void testGetTruncatedModelDefinition() throws Exception { + String modelId = "test-get-truncated-model-config"; + TrainedModelConfig config = buildTrainedModelConfig(modelId); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + List chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3); + + List docBuilders = IntStream.range(0, chunks.size()) + .mapToObj(i -> new TrainedModelDefinitionDoc.Builder() + .setDocNum(i) + .setCompressedString(chunks.get(i)) + .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) + .setDefinitionLength(chunks.get(i).length()) + .setEos(i == chunks.size() - 1) + .setModelId(modelId)) + .collect(Collectors.toList()); + boolean missingEos = randomBoolean(); + docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false); + for (int i = missingEos ? 
0 : 1 ; i < docBuilders.size(); ++i) { + TrainedModelDefinitionDoc doc = docBuilders.get(i).build(); + try(XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(), + new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")))) { + AtomicReference putDocHolder = new AtomicReference<>(); + blockingCall(listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource(xContentBuilder) + .setId(TrainedModelDefinitionDoc.docId(modelId, 0)) + .execute(listener), + putDocHolder, + exceptionHolder); + assertThat(exceptionHolder.get(), is(nullValue())); + } + } + AtomicReference getConfigHolder = new AtomicReference<>(); + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + assertThat(getConfigHolder.get(), is(nullValue())); + assertThat(exceptionHolder.get(), is(not(nullValue()))); + assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + } + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 86bc7d05d73..815540ac958 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -8,48 +8,29 @@ package org.elasticsearch.xpack.ml.dataframe.process; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import 
org.apache.logging.log4j.message.ParameterizedMessage; -import org.elasticsearch.Version; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.license.License; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; -import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; -import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; -import 
org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; -import java.time.Instant; -import java.util.Collections; import java.util.Iterator; -import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import static java.util.stream.Collectors.toList; public class AnalyticsResultProcessor { @@ -70,11 +51,10 @@ public class AnalyticsResultProcessor { private final DataFrameAnalyticsConfig analytics; private final DataFrameRowsJoiner dataFrameRowsJoiner; private final StatsHolder statsHolder; - private final TrainedModelProvider trainedModelProvider; private final DataFrameAnalyticsAuditor auditor; private final StatsPersister statsPersister; - private final ExtractedFields extractedFields; private final CountDownLatch completionLatch = new CountDownLatch(1); + private final ChunkedTrainedModelPersister chunkedTrainedModelPersister; private volatile String failure; private volatile boolean isCancelled; @@ -84,10 +64,15 @@ public class AnalyticsResultProcessor { this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.statsHolder = Objects.requireNonNull(statsHolder); - this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.auditor = Objects.requireNonNull(auditor); this.statsPersister = Objects.requireNonNull(statsPersister); - this.extractedFields = Objects.requireNonNull(extractedFields); + this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister( + trainedModelProvider, + analytics, + auditor, + this::setAndReportFailure, + extractedFields + ); } @Nullable @@ -166,9 +151,13 @@ public class AnalyticsResultProcessor { phaseProgress.getProgressPercent()); 
statsHolder.getProgressTracker().updatePhase(phaseProgress); } - TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder(); - if (inferenceModelBuilder != null) { - createAndIndexInferenceModel(inferenceModelBuilder); + ModelSizeInfo modelSize = result.getModelSizeInfo(); + if (modelSize != null) { + chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize); + } + TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk(); + if (trainedModelDefinitionChunk != null) { + chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk); } MemoryUsage memoryUsage = result.getMemoryUsage(); if (memoryUsage != null) { @@ -191,82 +180,6 @@ public class AnalyticsResultProcessor { } } - private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) { - TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel); - CountDownLatch latch = storeTrainedModel(trainedModelConfig); - - try { - if (latch.await(30, TimeUnit.SECONDS) == false) { - LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId()); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored")); - } - } - - private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) { - Instant createTime = Instant.now(); - String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); - TrainedModelDefinition definition = inferenceModel.build(); - String dependentVariable = getDependentVariable(); - List fieldNames = extractedFields.getAllFields(); - List fieldNamesWithoutDependentVariable = fieldNames.stream() - .map(ExtractedField::getName) - .filter(f -> f.equals(dependentVariable) == false) - .collect(toList()); - Map defaultFieldMapping = fieldNames.stream() 
- .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false)) - .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName)); - return TrainedModelConfig.builder() - .setModelId(modelId) - .setCreatedBy(XPackUser.NAME) - .setVersion(Version.CURRENT) - .setCreateTime(createTime) - // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags - .setTags(Collections.singletonList(analytics.getId())) - .setDescription(analytics.getDescription()) - .setMetadata(Collections.singletonMap("analytics_config", - XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) - .setEstimatedHeapMemory(definition.ramBytesUsed()) - .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) - .setParsedDefinition(inferenceModel) - .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) - .setLicenseLevel(License.OperationMode.PLATINUM.description()) - .setDefaultFieldMap(defaultFieldMapping) - .setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields))) - .build(); - } - - private String getDependentVariable() { - if (analytics.getAnalysis() instanceof Classification) { - return ((Classification)analytics.getAnalysis()).getDependentVariable(); - } - if (analytics.getAnalysis() instanceof Regression) { - return ((Regression)analytics.getAnalysis()).getDependentVariable(); - } - return null; - } - - private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) { - CountDownLatch latch = new CountDownLatch(1); - ActionListener storeListener = ActionListener.wrap( - aBoolean -> { - if (aBoolean == false) { - LOGGER.error("[{}] Storing trained model responded false", analytics.getId()); - setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false")); - } else { - LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), 
trainedModelConfig.getModelId()); - auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]"); - } - }, - e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e, - trainedModelConfig.getModelId())) - ); - trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); - return latch; - } - private void setAndReportFailure(Exception e) { LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e); failure = "error processing results; " + e.getMessage(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java new file mode 100644 index 00000000000..58e2227f4da --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -0,0 +1,235 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; +import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.license.License; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.security.user.XPackUser; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; +import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.extractor.MultiField; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import 
java.util.function.Consumer; +import java.util.stream.Collectors; + +import static java.util.stream.Collectors.toList; + +public class ChunkedTrainedModelPersister { + + private static final Logger LOGGER = LogManager.getLogger(ChunkedTrainedModelPersister.class); + private static final int STORE_TIMEOUT_SEC = 30; + private final TrainedModelProvider provider; + private final AtomicReference currentModelId; + private final DataFrameAnalyticsConfig analytics; + private final DataFrameAnalyticsAuditor auditor; + private final Consumer failureHandler; + private final ExtractedFields extractedFields; + private final AtomicBoolean readyToStoreNewModel = new AtomicBoolean(true); + + public ChunkedTrainedModelPersister(TrainedModelProvider provider, + DataFrameAnalyticsConfig analytics, + DataFrameAnalyticsAuditor auditor, + Consumer failureHandler, + ExtractedFields extractedFields) { + this.provider = provider; + this.currentModelId = new AtomicReference<>(""); + this.analytics = analytics; + this.auditor = auditor; + this.failureHandler = failureHandler; + this.extractedFields = extractedFields; + } + + public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedModelDefinitionChunk) { + if (Strings.isNullOrEmpty(this.currentModelId.get())) { + failureHandler.accept(ExceptionsHelper.serverError( + "chunked inference model definition is attempting to be stored before trained model configuration" + )); + return; + } + TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc(this.currentModelId.get()); + + CountDownLatch latch = storeTrainedModelDoc(trainedModelDefinitionDoc); + try { + if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) { + LOGGER.error("[{}] Timed out (30s) waiting for chunked inference definition to be stored", analytics.getId()); + if (trainedModelDefinitionChunk.isEos()) { + this.readyToStoreNewModel.set(true); + } + } + } catch (InterruptedException e) { + 
Thread.currentThread().interrupt(); + this.readyToStoreNewModel.set(true); + failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for chunked inference definition to be stored")); + } + } + + public void createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) { + if (readyToStoreNewModel.compareAndSet(true, false) == false) { + failureHandler.accept(ExceptionsHelper.serverError( + "new inference model is attempting to be stored before completion previous model storage" + )); + return; + } + TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModelSize); + CountDownLatch latch = storeTrainedModelMetadata(trainedModelConfig); + try { + if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) { + LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId()); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + this.readyToStoreNewModel.set(true); + failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored")); + } + } + + private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc) { + CountDownLatch latch = new CountDownLatch(1); + + // Latch is attached to this action as it is the last one to execute. 
+ ActionListener refreshListener = new LatchedActionListener<>(ActionListener.wrap( + refreshed -> { + if (refreshed != null) { + LOGGER.debug(() -> new ParameterizedMessage( + "[{}] refreshed inference index after model store", + analytics.getId() + )); + } + }, + e -> LOGGER.warn( + new ParameterizedMessage("[{}] failed to refresh inference index after model store", analytics.getId()), + e) + ), latch); + + // First, store the model and refresh if necessary + ActionListener storeListener = ActionListener.wrap( + r -> { + LOGGER.debug(() -> new ParameterizedMessage( + "[{}] stored trained model definition chunk [{}] [{}]", + analytics.getId(), + trainedModelDefinitionDoc.getModelId(), + trainedModelDefinitionDoc.getDocNum())); + if (trainedModelDefinitionDoc.isEos() == false) { + refreshListener.onResponse(null); + return; + } + LOGGER.info( + "[{}] finished storing trained model with id [{}]", + analytics.getId(), + this.currentModelId.get()); + auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]"); + this.currentModelId.set(""); + readyToStoreNewModel.set(true); + provider.refreshInferenceIndex(refreshListener); + }, + e -> { + this.readyToStoreNewModel.set(true); + failureHandler.accept(ExceptionsHelper.serverError( + "error storing trained model definition chunk [{}] with id [{}]", + e, + trainedModelDefinitionDoc.getModelId(), + trainedModelDefinitionDoc.getDocNum())); + refreshListener.onResponse(null); + } + ); + provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, storeListener); + return latch; + } + private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig) { + CountDownLatch latch = new CountDownLatch(1); + ActionListener storeListener = ActionListener.wrap( + aBoolean -> { + if (aBoolean == false) { + LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId()); + readyToStoreNewModel.set(true); + 
failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false")); + } else { + LOGGER.debug("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); + } + }, + e -> { + readyToStoreNewModel.set(true); + failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]", + e, + trainedModelConfig.getModelId())); + } + ); + provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); + return latch; + } + + private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) { + Instant createTime = Instant.now(); + String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); + currentModelId.set(modelId); + List fieldNames = extractedFields.getAllFields(); + String dependentVariable = getDependentVariable(); + List fieldNamesWithoutDependentVariable = fieldNames.stream() + .map(ExtractedField::getName) + .filter(f -> f.equals(dependentVariable) == false) + .collect(toList()); + Map defaultFieldMapping = fieldNames.stream() + .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false)) + .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName)); + return TrainedModelConfig.builder() + .setModelId(modelId) + .setCreatedBy(XPackUser.NAME) + .setVersion(Version.CURRENT) + .setCreateTime(createTime) + // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags + .setTags(Collections.singletonList(analytics.getId())) + .setDescription(analytics.getDescription()) + .setMetadata(Collections.singletonMap("analytics_config", + XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) + .setEstimatedHeapMemory(modelSize.ramBytesUsed()) + .setEstimatedOperations(modelSize.numOperations()) + .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable)) + 
.setLicenseLevel(License.OperationMode.PLATINUM.description()) + .setDefaultFieldMap(defaultFieldMapping) + .setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields))) + .build(); + } + + private String getDependentVariable() { + if (analytics.getAnalysis() instanceof Classification) { + return ((Classification)analytics.getAnalysis()).getDependentVariable(); + } + if (analytics.getAnalysis() instanceof Regression) { + return ((Regression)analytics.getAnalysis()).getDependentVariable(); + } + return null; + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 0a05f13c118..c7e6f9a1c43 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -8,31 +8,27 @@ package org.elasticsearch.xpack.ml.dataframe.process.results; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; -import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; 
import java.io.IOException; -import java.util.Collections; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; -import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE; public class AnalyticsResult implements ToXContentObject { public static final ParseField TYPE = new ParseField("analytics_result"); private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress"); - private static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info"); + private static final ParseField COMPRESSED_INFERENCE_MODEL = new ParseField("compressed_inference_model"); private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage"); private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats"); private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats"); @@ -42,53 +38,50 @@ public class AnalyticsResult implements ToXContentObject { a -> new AnalyticsResult( (RowResults) a[0], (PhaseProgress) a[1], - (TrainedModelDefinition.Builder) a[2], - (MemoryUsage) a[3], - (OutlierDetectionStats) a[4], - (ClassificationStats) a[5], - (RegressionStats) a[6], - (ModelSizeInfo) a[7] + (MemoryUsage) a[2], + (OutlierDetectionStats) a[3], + (ClassificationStats) a[4], + (RegressionStats) a[5], + (ModelSizeInfo) a[6], + (TrainedModelDefinitionChunk) a[7] )); static { PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS); - // TODO change back to STRICT_PARSER once native side is aligned - PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL); PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, 
ANALYTICS_MEMORY_USAGE); PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS); PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS); PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS); PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO); + PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinitionChunk.PARSER, COMPRESSED_INFERENCE_MODEL); } private final RowResults rowResults; private final PhaseProgress phaseProgress; - private final TrainedModelDefinition.Builder inferenceModelBuilder; - private final TrainedModelDefinition inferenceModel; private final MemoryUsage memoryUsage; private final OutlierDetectionStats outlierDetectionStats; private final ClassificationStats classificationStats; private final RegressionStats regressionStats; private final ModelSizeInfo modelSizeInfo; + private final TrainedModelDefinitionChunk trainedModelDefinitionChunk; public AnalyticsResult(@Nullable RowResults rowResults, @Nullable PhaseProgress phaseProgress, - @Nullable TrainedModelDefinition.Builder inferenceModelBuilder, @Nullable MemoryUsage memoryUsage, @Nullable OutlierDetectionStats outlierDetectionStats, @Nullable ClassificationStats classificationStats, @Nullable RegressionStats regressionStats, - @Nullable ModelSizeInfo modelSizeInfo) { + @Nullable ModelSizeInfo modelSizeInfo, + @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) { this.rowResults = rowResults; this.phaseProgress = phaseProgress; - this.inferenceModelBuilder = inferenceModelBuilder; - this.inferenceModel = inferenceModelBuilder == null ? 
null : inferenceModelBuilder.build(); this.memoryUsage = memoryUsage; this.outlierDetectionStats = outlierDetectionStats; this.classificationStats = classificationStats; this.regressionStats = regressionStats; this.modelSizeInfo = modelSizeInfo; + this.trainedModelDefinitionChunk = trainedModelDefinitionChunk; } public RowResults getRowResults() { @@ -99,10 +92,6 @@ public class AnalyticsResult implements ToXContentObject { return phaseProgress; } - public TrainedModelDefinition.Builder getInferenceModelBuilder() { - return inferenceModelBuilder; - } - public MemoryUsage getMemoryUsage() { return memoryUsage; } @@ -123,6 +112,10 @@ public class AnalyticsResult implements ToXContentObject { return modelSizeInfo; } + public TrainedModelDefinitionChunk getTrainedModelDefinitionChunk() { + return trainedModelDefinitionChunk; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -132,11 +125,6 @@ public class AnalyticsResult implements ToXContentObject { if (phaseProgress != null) { builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress); } - if (inferenceModel != null) { - builder.field(INFERENCE_MODEL.getPreferredName(), - inferenceModel, - new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"))); - } if (memoryUsage != null) { builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params); } @@ -152,6 +140,9 @@ public class AnalyticsResult implements ToXContentObject { if (modelSizeInfo != null) { builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo); } + if (trainedModelDefinitionChunk != null) { + builder.field(COMPRESSED_INFERENCE_MODEL.getPreferredName(), trainedModelDefinitionChunk); + } builder.endObject(); return builder; } @@ -168,17 +159,17 @@ public class AnalyticsResult implements ToXContentObject { AnalyticsResult that = (AnalyticsResult) other; return Objects.equals(rowResults, that.rowResults) && 
Objects.equals(phaseProgress, that.phaseProgress) - && Objects.equals(inferenceModel, that.inferenceModel) && Objects.equals(memoryUsage, that.memoryUsage) && Objects.equals(outlierDetectionStats, that.outlierDetectionStats) && Objects.equals(classificationStats, that.classificationStats) && Objects.equals(modelSizeInfo, that.modelSizeInfo) + && Objects.equals(trainedModelDefinitionChunk, that.trainedModelDefinitionChunk) && Objects.equals(regressionStats, that.regressionStats); } @Override public int hashCode() { - return Objects.hash(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats, - regressionStats); + return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats, + regressionStats, modelSizeInfo, trainedModelDefinitionChunk); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java new file mode 100644 index 00000000000..3d5ce84a6af --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/TrainedModelDefinitionChunk.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.process.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class TrainedModelDefinitionChunk implements ToXContentObject { + + private static final ParseField DEFINITION = new ParseField("definition"); + private static final ParseField DOC_NUM = new ParseField("doc_num"); + private static final ParseField EOS = new ParseField("eos"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "chunked_trained_model_definition", + a -> new TrainedModelDefinitionChunk((String) a[0], (Integer) a[1], (Boolean) a[2])); + + static { + PARSER.declareString(constructorArg(), DEFINITION); + PARSER.declareInt(constructorArg(), DOC_NUM); + PARSER.declareBoolean(optionalConstructorArg(), EOS); + } + + private final String definition; + private final int docNum; + private final Boolean eos; + + public TrainedModelDefinitionChunk(String definition, int docNum, Boolean eos) { + this.definition = definition; + this.docNum = docNum; + this.eos = eos; + } + + public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId) { + return new TrainedModelDefinitionDoc.Builder() + .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) + .setModelId(modelId) + .setDefinitionLength(definition.length()) + .setDocNum(docNum) + .setCompressedString(definition) + .setEos(isEos()) + 
.build(); + } + + public boolean isEos() { + return eos != null && eos; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DEFINITION.getPreferredName(), definition); + builder.field(DOC_NUM.getPreferredName(), docNum); + if (eos != null) { + builder.field(EOS.getPreferredName(), eos); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelDefinitionChunk that = (TrainedModelDefinitionChunk) o; + return docNum == that.docNum + && Objects.equals(definition, that.definition) + && Objects.equals(eos, that.eos); + } + + @Override + public int hashCode() { + return Objects.hash(definition, docNum, eos); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java index b2b53ec445c..de332b77af6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java @@ -33,6 +33,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version"); public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length"); public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length"); + public static final ParseField EOS = new ParseField("eos"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = 
createParser(true); @@ -48,6 +49,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION); parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH); parser.declareLong(TrainedModelDefinitionDoc.Builder::setTotalDefinitionLength, TOTAL_DEFINITION_LENGTH); + parser.declareBoolean(TrainedModelDefinitionDoc.Builder::setEos, EOS); return parser; } @@ -63,23 +65,26 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { private final String compressedString; private final String modelId; private final int docNum; - private final long totalDefinitionLength; + // for BWC + private final Long totalDefinitionLength; private final long definitionLength; private final int compressionVersion; + private final boolean eos; private TrainedModelDefinitionDoc(String compressedString, String modelId, int docNum, - long totalDefinitionLength, + Long totalDefinitionLength, long definitionLength, - int compressionVersion) { + int compressionVersion, + boolean eos) { this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION); this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); if (docNum < 0) { throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0"); } this.docNum = docNum; - if (totalDefinitionLength <= 0L) { + if (totalDefinitionLength != null && totalDefinitionLength <= 0L) { throw new IllegalArgumentException("[total_definition_length] must be greater than 0"); } this.totalDefinitionLength = totalDefinitionLength; @@ -88,6 +93,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { } this.definitionLength = definitionLength; this.compressionVersion = compressionVersion; + this.eos = eos; } public String getCompressedString() { @@ -102,7 +108,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { 
return docNum; } - public long getTotalDefinitionLength() { + public Long getTotalDefinitionLength() { return totalDefinitionLength; } @@ -114,16 +120,24 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { return compressionVersion; } + public boolean isEos() { + return eos; + } + + public String getDocId() { + return docId(modelId, docNum); + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); builder.field(DOC_NUM.getPreferredName(), docNum); - builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength); builder.field(DEFINITION_LENGTH.getPreferredName(), definitionLength); builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion); builder.field(DEFINITION.getPreferredName(), compressedString); + builder.field(EOS.getPreferredName(), eos); builder.endObject(); return builder; } @@ -143,12 +157,13 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { Objects.equals(definitionLength, that.definitionLength) && Objects.equals(totalDefinitionLength, that.totalDefinitionLength) && Objects.equals(compressionVersion, that.compressionVersion) && + Objects.equals(eos, that.eos) && Objects.equals(compressedString, that.compressedString); } @Override public int hashCode() { - return Objects.hash(modelId, docNum, totalDefinitionLength, definitionLength, compressionVersion, compressedString); + return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos); } public static class Builder { @@ -156,9 +171,10 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { private String modelId; private String compressedString; private int docNum; - private long totalDefinitionLength; + private Long 
totalDefinitionLength; private long definitionLength; private int compressionVersion; + private boolean eos; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -190,6 +206,11 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { return this; } + public Builder setEos(boolean eos) { + this.eos = eos; + return this; + } + public TrainedModelDefinitionDoc build() { return new TrainedModelDefinitionDoc( this.compressedString, @@ -197,7 +218,8 @@ public class TrainedModelDefinitionDoc implements ToXContentObject { this.docNum, this.totalDefinitionLength, this.definitionLength, - this.compressionVersion); + this.compressionVersion, + this.eos); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index a759e10f24d..26b8cae6f0e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -14,10 +14,14 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.admin.indices.refresh.RefreshAction; +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; +import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.MultiSearchAction; import 
org.elasticsearch.action.search.MultiSearchRequest; @@ -143,6 +147,74 @@ public class TrainedModelProvider { storeTrainedModelAndDefinition(trainedModelConfig, listener); } + public void storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig, + ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); + return; + } + assert trainedModelConfig.getModelDefinition() == null; + + executeAsyncWithOrigin(client, + ML_ORIGIN, + IndexAction.INSTANCE, + createRequest(trainedModelConfig.getModelId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelConfig), + ActionListener.wrap( + indexResponse -> listener.onResponse(true), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, + RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelConfig.getModelId())); + } + } + )); + } + + public void storeTrainedModelDefinitionDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc, ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(trainedModelDefinitionDoc.getModelId())) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelDefinitionDoc.getModelId()))); + return; + } + + executeAsyncWithOrigin(client, + ML_ORIGIN, + IndexAction.INSTANCE, + createRequest(trainedModelDefinitionDoc.getDocId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelDefinitionDoc), + ActionListener.wrap( + indexResponse -> listener.onResponse(null), + e -> { + if (ExceptionsHelper.unwrapCause(e) 
instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_DOC_EXISTS, + trainedModelDefinitionDoc.getModelId(), + trainedModelDefinitionDoc.getDocNum()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, + RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelDefinitionDoc.getModelId())); + } + } + )); + } + + public void refreshInferenceIndex(ActionListener listener) { + executeAsyncWithOrigin(client, + ML_ORIGIN, + RefreshAction.INSTANCE, + new RefreshRequest(InferenceIndexConstants.INDEX_PATTERN), + listener); + } + private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig, ActionListener listener) { @@ -165,7 +237,8 @@ public class TrainedModelProvider { .setCompressedString(chunkedStrings.get(i)) .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION) .setDefinitionLength(chunkedStrings.get(i).length()) - .setTotalDefinitionLength(compressedString.length()) + // If it is the last doc, it is the EOS + .setEos(i == chunkedStrings.size() - 1) .build()); } } catch (IOException ex) { @@ -265,6 +338,9 @@ public class TrainedModelProvider { .unmappedType("long")) .request(); executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + // TODO how could we stream in the model definition WHILE parsing it? + // This would reduce the overall memory usage as we won't have to load the whole compressed string + // XContentParser supports streams. 
searchResponse -> { if (searchResponse.getHits().getHits().length == 0) { listener.onFailure(new ResourceNotFoundException( @@ -274,19 +350,16 @@ public class TrainedModelProvider { List docs = handleHits(searchResponse.getHits().getHits(), modelId, this::parseModelDefinitionDocLenientlyFromSource); - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; + try { + String compressedString = getDefinitionFromDocs(docs, modelId); + InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( + compressedString, + InferenceDefinition::fromXContent, + xContentRegistry); + listener.onResponse(inferenceDefinition); + } catch (ElasticsearchException elasticsearchException) { + listener.onFailure(elasticsearchException); } - InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( - compressedString, - InferenceDefinition::fromXContent, - xContentRegistry); - listener.onResponse(inferenceDefinition); }, e -> { if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { @@ -361,15 +434,14 @@ public class TrainedModelProvider { List docs = handleSearchItems(multiSearchResponse.getResponses()[1], modelId, this::parseModelDefinitionDocLenientlyFromSource); - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + try { + String compressedString = getDefinitionFromDocs(docs, modelId); + builder.setDefinitionFromString(compressedString); + } catch (ElasticsearchException 
elasticsearchException) { + listener.onFailure(elasticsearchException); return; } - builder.setDefinitionFromString(compressedString); + } catch (ResourceNotFoundException ex) { listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); @@ -806,6 +878,26 @@ public class TrainedModelProvider { return results; } + private static String getDefinitionFromDocs(List docs, String modelId) throws ElasticsearchException { + String compressedString = docs.stream() + .map(TrainedModelDefinitionDoc::getCompressedString) + .collect(Collectors.joining()); + // BWC for when we tracked the total definition length + // TODO: remove in 9 + if (docs.get(0).getTotalDefinitionLength() != null) { + if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { + throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); + } + } else { + TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1); + // Either we are missing the last doc, or some previous doc + if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) { + throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)); + } + } + return compressedString; + } + static List chunkStringWithSize(String str, int chunkSize) { List subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize)); for (int i = 0; i < str.length();i += chunkSize) { @@ -836,14 +928,18 @@ public class TrainedModelProvider { } } + private IndexRequest createRequest(String docId, String index, ToXContentObject body) { + return createRequest(new IndexRequest(index), docId, body); + } + private IndexRequest createRequest(String docId, ToXContentObject body) { + return createRequest(new IndexRequest(), docId, body); + } + + private IndexRequest createRequest(IndexRequest request, String docId, ToXContentObject body) { try (XContentBuilder builder = 
XContentFactory.jsonBuilder()) { XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS); - - return new IndexRequest() - .opType(DocWriteRequest.OpType.CREATE) - .id(docId) - .source(source); + return request.opType(DocWriteRequest.OpType.CREATE).id(docId).source(source); } catch (IOException ex) { // This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again // that is not the users fault. We did something wrong and should throw. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 506419c53d6..5c3e2b9ce19 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -5,32 +5,20 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; -import org.elasticsearch.Version; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.license.License; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import 
org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; -import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; -import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; -import org.elasticsearch.xpack.ml.extractor.MultiField; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.junit.Before; @@ -38,20 +26,14 @@ import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Mockito; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; -import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasKey; -import static org.hamcrest.Matchers.startsWith; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -156,90 +138,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase { assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } - @SuppressWarnings("unchecked") - public void testProcess_GivenInferenceModelIsStoredSuccessfully() { - 
givenDataFrameRows(0); - - doAnswer(invocationOnMock -> { - ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; - storeListener.onResponse(true); - return null; - }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - - List extractedFieldList = new ArrayList<>(3); - extractedFieldList.add(new DocValueField("foo", Collections.emptySet())); - extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet()))); - extractedFieldList.add(new DocValueField("baz", Collections.emptySet())); - TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION; - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); - AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList); - - resultProcessor.process(process); - resultProcessor.awaitForCompletion(); - - ArgumentCaptor storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class); - verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class)); - - TrainedModelConfig storedModel = storedModelCaptor.getValue(); - assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM)); - assertThat(storedModel.getModelId(), containsString(JOB_ID)); - assertThat(storedModel.getVersion(), equalTo(Version.CURRENT)); - assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME)); - assertThat(storedModel.getTags(), contains(JOB_ID)); - assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); - assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build())); - assertThat(storedModel.getDefaultFieldMap(), equalTo(Collections.singletonMap("bar", "bar.keyword"))); - 
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz"))); - assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); - assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); - if (targetType.equals(TargetType.CLASSIFICATION)) { - assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); - } else { - assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression")); - } - Map metadata = storedModel.getMetadata(); - assertThat(metadata.size(), equalTo(1)); - assertThat(metadata, hasKey("analytics_config")); - Map analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(), - true); - assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config"))); - - ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); - verify(auditor).info(eq(JOB_ID), auditCaptor.capture()); - assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID)); - Mockito.verifyNoMoreInteractions(auditor); - } - - - @SuppressWarnings("unchecked") - public void testProcess_GivenInferenceModelFailedToStore() { - givenDataFrameRows(0); - - doAnswer(invocationOnMock -> { - ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; - storeListener.onFailure(new RuntimeException("some failure")); - return null; - }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - - TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? 
TargetType.REGRESSION : TargetType.CLASSIFICATION; - TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType); - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null))); - AnalyticsResultProcessor resultProcessor = createResultProcessor(); - - resultProcessor.process(process); - resultProcessor.awaitForCompletion(); - - // This test verifies the processor knows how to handle a failure on storing the model and completes normally - ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); - verify(auditor).error(eq(JOB_ID), auditCaptor.capture()); - assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID)); - Mockito.verifyNoMoreInteractions(auditor); - - assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID)); - assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); - } - private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } @@ -256,7 +154,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase { } private AnalyticsResultProcessor createResultProcessor(List fieldNames) { - return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, statsHolder, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java new file mode 100644 index 00000000000..ee01e297907 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.license.License; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.security.user.XPackUser; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; +import org.elasticsearch.xpack.ml.extractor.DocValueField; +import org.elasticsearch.xpack.ml.extractor.ExtractedField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; +import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; 
+import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class ChunkedTrainedModelPersisterTests extends ESTestCase { + + private static final String JOB_ID = "analytics-result-processor-tests"; + private static final String JOB_DESCRIPTION = "This describes the job of these tests"; + + private TrainedModelProvider trainedModelProvider; + private DataFrameAnalyticsAuditor auditor; + + @Before + public void setUpMocks() { + trainedModelProvider = mock(TrainedModelProvider.class); + auditor = mock(DataFrameAnalyticsAuditor.class); + } + + @SuppressWarnings("unchecked") + public void testPersistAllDocs() { + DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder() + .setId(JOB_ID) + .setDescription(JOB_DESCRIPTION) + .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null)) + .setDest(new DataFrameAnalyticsDest("my_dest", null)) + .setAnalysis(randomBoolean() ? 
new Regression("foo") : new Classification("foo")) + .build(); + List extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet())); + + doAnswer(invocationOnMock -> { + ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; + storeListener.onResponse(true); + return null; + }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class)); + + doAnswer(invocationOnMock -> { + ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; + storeListener.onResponse(null); + return null; + }).when(trainedModelProvider).storeTrainedModelDefinitionDoc(any(TrainedModelDefinitionDoc.class), any(ActionListener.class)); + + ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); + ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); + TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 0, false); + TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true); + + resultProcessor.createAndIndexInferenceModelMetadata(modelSizeInfo); + resultProcessor.createAndIndexInferenceModelDoc(chunk1); + resultProcessor.createAndIndexInferenceModelDoc(chunk2); + + ArgumentCaptor storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class); + verify(trainedModelProvider).storeTrainedModelMetadata(storedModelCaptor.capture(), any(ActionListener.class)); + + ArgumentCaptor storedDocCapture = ArgumentCaptor.forClass(TrainedModelDefinitionDoc.class); + verify(trainedModelProvider, times(2)) + .storeTrainedModelDefinitionDoc(storedDocCapture.capture(), any(ActionListener.class)); + + TrainedModelConfig storedModel = storedModelCaptor.getValue(); + assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM)); + assertThat(storedModel.getModelId(), containsString(JOB_ID)); + 
assertThat(storedModel.getVersion(), equalTo(Version.CURRENT)); + assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME)); + assertThat(storedModel.getTags(), contains(JOB_ID)); + assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); + assertThat(storedModel.getModelDefinition(), is(nullValue())); + assertThat(storedModel.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); + assertThat(storedModel.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations())); + if (analyticsConfig.getAnalysis() instanceof Classification) { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification")); + } else { + assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression")); + } + Map metadata = storedModel.getMetadata(); + assertThat(metadata.size(), equalTo(1)); + assertThat(metadata, hasKey("analytics_config")); + Map analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(), + true); + assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config"))); + + TrainedModelDefinitionDoc storedDoc1 = storedDocCapture.getAllValues().get(0); + assertThat(storedDoc1.getDocNum(), equalTo(0)); + TrainedModelDefinitionDoc storedDoc2 = storedDocCapture.getAllValues().get(1); + assertThat(storedDoc2.getDocNum(), equalTo(1)); + + assertThat(storedModel.getModelId(), equalTo(storedDoc1.getModelId())); + assertThat(storedModel.getModelId(), equalTo(storedDoc2.getModelId())); + + ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); + verify(auditor).info(eq(JOB_ID), auditCaptor.capture()); + assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID)); + Mockito.verifyNoMoreInteractions(auditor); + } + + private ChunkedTrainedModelPersister createChunkedTrainedModelPersister(List fieldNames, + DataFrameAnalyticsConfig analyticsConfig) { + return new ChunkedTrainedModelPersister(trainedModelProvider, + 
analyticsConfig, + auditor, + (unused)->{}, + new ExtractedFields(fieldNames, Collections.emptyMap())); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 2bf72318295..3f48583ef34 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -20,8 +20,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; @@ -46,21 +44,18 @@ public class AnalyticsResultTests extends AbstractXContentTestCase