[7.x] [ML] handles compressed model stream from native process (#58009) (#58836)

* [ML] handles compressed model stream from native process (#58009)

This moves model storage from handling the fully parsed JSON string to handling two separate types of documents.

1. ModelSizeInfo which contains model size information 
2. TrainedModelDefinitionChunk which contains a particular chunk of the compressed model definition string.

`model_size_info` is assumed to be handled first. This generates the model_id and stores the initial trained model config object. Each chunk is then assumed to arrive in the correct order, so that concatenating the chunks yields the compressed definition.


Native side change: https://github.com/elastic/ml-cpp/pull/1349
This commit is contained in:
Benjamin Trent 2020-07-01 15:14:31 -04:00 committed by GitHub
parent 9c77862a23
commit c64e283dbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 853 additions and 286 deletions

View File

@ -89,6 +89,7 @@ public final class Messages {
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.integration;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionModule;
@ -67,7 +66,6 @@ import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String BOOLEAN_FIELD = "boolean-field";

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.ml.integration;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionModule;
import org.elasticsearch.action.DocWriteRequest;
@ -45,7 +44,6 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String NUMERICAL_FEATURE_FIELD = "feature";

View File

@ -0,0 +1,130 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import static org.hamcrest.Matchers.equalTo;
public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
private TrainedModelProvider trainedModelProvider;
/**
 * Test fixture setup: creates the {@link TrainedModelProvider} under test from the node client and this
 * test's {@link #xContentRegistry()}, then calls {@code waitForMlTemplates()}.
 */
@Before
public void createComponents() throws Exception {
trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
waitForMlTemplates();
}
/**
 * Pushes a model through {@link ChunkedTrainedModelPersister} (size info first, then every definition
 * chunk in order) and checks that the provider hands back a single model whose compressed definition is
 * the concatenation of the chunks and whose estimates come from the {@link ModelSizeInfo}.
 */
public void testStoreModelViaChunkedPersister() throws IOException {
    final String modelId = "stored-chunked-model";
    final DataFrameAnalyticsConfig analytics = new DataFrameAnalyticsConfig.Builder()
        .setId(modelId)
        .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
        .setDest(new DataFrameAnalyticsDest("my_dest", null))
        .setAnalysis(new Regression("foo"))
        .build();
    final List<ExtractedField> fields = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));

    // The definition we expect to read back, split into roughly three pieces.
    final String expectedDefinition = buildTrainedModelConfigBuilder(modelId).build().getCompressedDefinition();
    final List<String> pieces = chunkStringWithSize(expectedDefinition, expectedDefinition.length() / 3);

    final ChunkedTrainedModelPersister chunkedPersister = new ChunkedTrainedModelPersister(
        trainedModelProvider,
        analytics,
        new DataFrameAnalyticsAuditor(client(), "test-node"),
        // any reported failure must fail the test
        (ex) -> { throw new ElasticsearchException(ex); },
        new ExtractedFields(fields, Collections.emptyMap())
    );

    // Accuracy for size is not tested here
    final ModelSizeInfo sizeInfo = ModelSizeInfoTests.createRandom();
    chunkedPersister.createAndIndexInferenceModelMetadata(sizeInfo);
    final int lastIndex = pieces.size() - 1;
    for (int docNum = 0; docNum <= lastIndex; docNum++) {
        final boolean endOfStream = docNum == lastIndex;
        chunkedPersister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(pieces.get(docNum), docNum, endOfStream));
    }

    // Exactly one model id should match the analytics id prefix.
    final PlainActionFuture<Tuple<Long, Set<String>>> idsFuture = new PlainActionFuture<>();
    trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), idsFuture);
    final Tuple<Long, Set<String>> matchedIds = idsFuture.actionGet();
    assertThat(matchedIds.v1(), equalTo(1L));

    final PlainActionFuture<TrainedModelConfig> modelFuture = new PlainActionFuture<>();
    trainedModelProvider.getTrainedModel(matchedIds.v2().iterator().next(), true, modelFuture);
    final TrainedModelConfig stored = modelFuture.actionGet();
    assertThat(stored.getCompressedDefinition(), equalTo(expectedDefinition));
    assertThat(stored.getEstimatedOperations(), equalTo((long)sizeInfo.numOperations()));
    assertThat(stored.getEstimatedHeapMemory(), equalTo(sizeInfo.ramBytesUsed()));
}
/**
 * Builds a {@link TrainedModelConfig.Builder} for {@code modelId} around a random regression definition.
 * <p>
 * The estimated heap memory and operation count are computed from the same definition that is attached
 * to the config, so the estimates and the stored definition agree. The definition is built only once
 * for those computations.
 *
 * @param modelId the id to assign to the config
 * @return a builder pre-populated with test values
 */
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
    TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION);
    TrainedModelDefinition definition = definitionBuilder.build();
    long bytesUsed = definition.ramBytesUsed();
    long operations = definition.getTrainedModel().estimatedNumOperations();
    return TrainedModelConfig.builder()
        .setCreatedBy("ml_test")
        .setParsedDefinition(definitionBuilder)
        .setDescription("trained model config for test")
        .setModelId(modelId)
        .setVersion(Version.CURRENT)
        .setLicenseLevel(License.OperationMode.PLATINUM.description())
        .setEstimatedHeapMemory(bytesUsed)
        .setEstimatedOperations(operations)
        .setInput(TrainedModelInputTests.createRandomInput());
}
/**
 * Splits {@code str} into consecutive substrings of at most {@code chunkSize} characters.
 * Concatenating the returned chunks in order yields {@code str} again; only the last chunk may be
 * shorter than {@code chunkSize}. An empty string yields an empty list.
 *
 * @param str       the string to split
 * @param chunkSize the maximum length of each chunk; must be positive
 * @return the chunks, in order
 * @throws IllegalArgumentException if {@code chunkSize} is zero or negative
 */
public static List<String> chunkStringWithSize(String str, int chunkSize) {
    // A zero size would otherwise surface as an opaque division by zero below (and the loop could
    // never advance); a negative size would surface as an ArrayList capacity error.
    if (chunkSize <= 0) {
        throw new IllegalArgumentException("chunkSize must be positive but was [" + chunkSize + "]");
    }
    List<String> subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize);
    for (int i = 0; i < str.length(); i += chunkSize) {
        subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
    }
    return subStrings;
}
/**
 * Registry combining the named XContent parsers from {@link MlInferenceNamedXContentProvider} and
 * {@link MlModelSizeNamedXContentProvider}, in that order.
 */
@Override
public NamedXContentRegistry xContentRegistry() {
    List<NamedXContentRegistry.Entry> entries =
        new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
    entries.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
    return new NamedXContentRegistry(entries);
}
}

View File

@ -32,8 +32,11 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
@ -157,8 +160,8 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
}
public void testGetTruncatedModelDefinition() throws Exception {
String modelId = "test-get-truncated-model-config";
public void testGetTruncatedModelDeprecatedDefinition() throws Exception {
String modelId = "test-get-truncated-legacy-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@ -196,6 +199,51 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
}
/**
 * Stores a model config, then indexes chunked definition documents that are deliberately incomplete and
 * asserts that fetching the model with its definition fails with the
 * {@code MODEL_DEFINITION_TRUNCATED} message.
 */
public void testGetTruncatedModelDefinition() throws Exception {
String modelId = "test-get-truncated-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId);
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
// Split the compressed definition into chunks and prepare one definition doc builder per chunk;
// initially only the final chunk is flagged as end-of-stream.
List<String> chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3);
List<TrainedModelDefinitionDoc.Builder> docBuilders = IntStream.range(0, chunks.size())
.mapToObj(i -> new TrainedModelDefinitionDoc.Builder()
.setDocNum(i)
.setCompressedString(chunks.get(i))
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
.setDefinitionLength(chunks.get(i).length())
.setEos(i == chunks.size() - 1)
.setModelId(modelId))
.collect(Collectors.toList());
// Randomly pick which kind of truncation to simulate: when missingEos is true the final chunk is
// not flagged as end-of-stream; otherwise the loop below starts at index 1 and skips the first chunk.
boolean missingEos = randomBoolean();
docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false);
for (int i = missingEos ? 0 : 1 ; i < docBuilders.size(); ++i) {
TrainedModelDefinitionDoc doc = docBuilders.get(i).build();
try(XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(),
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")))) {
AtomicReference<IndexResponse> putDocHolder = new AtomicReference<>();
// NOTE(review): every chunk is indexed under the doc id for chunk 0 rather than its own index i,
// so each write replaces the previously indexed document. Presumably this is meant to overwrite
// the definition written by storeTrainedModel above - confirm it is intentional.
blockingCall(listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setSource(xContentBuilder)
.setId(TrainedModelDefinitionDoc.docId(modelId, 0))
.execute(listener),
putDocHolder,
exceptionHolder);
assertThat(exceptionHolder.get(), is(nullValue()));
}
}
// Fetching the model including its definition must now fail with the truncation error.
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
assertThat(getConfigHolder.get(), is(nullValue()));
assertThat(exceptionHolder.get(), is(not(nullValue())));
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
}
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")

View File

@ -8,48 +8,29 @@ package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import java.time.Instant;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static java.util.stream.Collectors.toList;
public class AnalyticsResultProcessor {
@ -70,11 +51,10 @@ public class AnalyticsResultProcessor {
private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final StatsHolder statsHolder;
private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor;
private final StatsPersister statsPersister;
private final ExtractedFields extractedFields;
private final CountDownLatch completionLatch = new CountDownLatch(1);
private final ChunkedTrainedModelPersister chunkedTrainedModelPersister;
private volatile String failure;
private volatile boolean isCancelled;
@ -84,10 +64,15 @@ public class AnalyticsResultProcessor {
this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.statsHolder = Objects.requireNonNull(statsHolder);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor);
this.statsPersister = Objects.requireNonNull(statsPersister);
this.extractedFields = Objects.requireNonNull(extractedFields);
this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister(
trainedModelProvider,
analytics,
auditor,
this::setAndReportFailure,
extractedFields
);
}
@Nullable
@ -166,9 +151,13 @@ public class AnalyticsResultProcessor {
phaseProgress.getProgressPercent());
statsHolder.getProgressTracker().updatePhase(phaseProgress);
}
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
if (inferenceModelBuilder != null) {
createAndIndexInferenceModel(inferenceModelBuilder);
ModelSizeInfo modelSize = result.getModelSizeInfo();
if (modelSize != null) {
chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize);
}
TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk();
if (trainedModelDefinitionChunk != null) {
chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk);
}
MemoryUsage memoryUsage = result.getMemoryUsage();
if (memoryUsage != null) {
@ -191,82 +180,6 @@ public class AnalyticsResultProcessor {
}
}
/**
 * Builds the trained model config for {@code inferenceModel}, starts storing it and waits up to
 * 30 seconds for the store to finish. A timeout is only logged; an interrupt restores the thread's
 * interrupt flag and is reported as a failure.
 */
private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
    CountDownLatch storeLatch = storeTrainedModel(createTrainedModelConfig(inferenceModel));
    boolean finishedInTime;
    try {
        finishedInTime = storeLatch.await(30, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
        return;
    }
    if (finishedInTime == false) {
        LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId());
    }
}
/**
 * Assembles the {@link TrainedModelConfig} for a model produced by this analytics job.
 * The model id is the analytics id followed by the creation time in epoch millis. The input field
 * names are all extracted field names except the dependent variable, and the default field map maps
 * the parent field of each non-dependent {@link MultiField} to that field's name.
 */
private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) {
    final Instant now = Instant.now();
    final String newModelId = analytics.getId() + "-" + now.toEpochMilli();
    final TrainedModelDefinition builtDefinition = inferenceModel.build();
    final String target = getDependentVariable();
    final List<ExtractedField> allFields = extractedFields.getAllFields();

    final List<String> inputFieldNames = allFields.stream()
        .map(ExtractedField::getName)
        .filter(name -> name.equals(target) == false)
        .collect(Collectors.toList());
    final Map<String, String> multiFieldDefaults = allFields.stream()
        .filter(field -> field instanceof MultiField)
        .filter(field -> field.getName().equals(target) == false)
        .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
    final Map<String, Object> metadata = Collections.singletonMap("analytics_config",
        XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true));

    return TrainedModelConfig.builder()
        .setModelId(newModelId)
        .setCreatedBy(XPackUser.NAME)
        .setVersion(Version.CURRENT)
        .setCreateTime(now)
        // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
        .setTags(Collections.singletonList(analytics.getId()))
        .setDescription(analytics.getDescription())
        .setMetadata(metadata)
        .setEstimatedHeapMemory(builtDefinition.ramBytesUsed())
        .setEstimatedOperations(builtDefinition.getTrainedModel().estimatedNumOperations())
        .setParsedDefinition(inferenceModel)
        .setInput(new TrainedModelInput(inputFieldNames))
        .setLicenseLevel(License.OperationMode.PLATINUM.description())
        .setDefaultFieldMap(multiFieldDefaults)
        .setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields)))
        .build();
}
/**
 * Returns the dependent variable of the configured analysis when it is a {@link Classification} or a
 * {@link Regression}; returns {@code null} for any other analysis type.
 */
private String getDependentVariable() {
    final String dependentVariable;
    if (analytics.getAnalysis() instanceof Classification) {
        dependentVariable = ((Classification) analytics.getAnalysis()).getDependentVariable();
    } else if (analytics.getAnalysis() instanceof Regression) {
        dependentVariable = ((Regression) analytics.getAnalysis()).getDependentVariable();
    } else {
        dependentVariable = null;
    }
    return dependentVariable;
}
/**
 * Asks the provider to store {@code trainedModelConfig} and returns a latch that is released once the
 * store call has responded. A {@code true} response is logged and audited; a {@code false} response
 * or an exception is reported through {@code setAndReportFailure}.
 */
private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
    final CountDownLatch storedLatch = new CountDownLatch(1);
    final ActionListener<Boolean> onStored = ActionListener.wrap(
        acknowledged -> {
            if (acknowledged) {
                LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
                auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
                return;
            }
            LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
            setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
        },
        failure -> setAndReportFailure(
            ExceptionsHelper.serverError("error storing trained model with id [{}]", failure, trainedModelConfig.getModelId()))
    );
    trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(onStored, storedLatch));
    return storedLatch;
}
private void setAndReportFailure(Exception e) {
LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
failure = "error processing results; " + e.getMessage();

View File

@ -0,0 +1,235 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static java.util.stream.Collectors.toList;
public class ChunkedTrainedModelPersister {
private static final Logger LOGGER = LogManager.getLogger(ChunkedTrainedModelPersister.class);
private static final int STORE_TIMEOUT_SEC = 30;
private final TrainedModelProvider provider;
private final AtomicReference<String> currentModelId;
private final DataFrameAnalyticsConfig analytics;
private final DataFrameAnalyticsAuditor auditor;
private final Consumer<Exception> failureHandler;
private final ExtractedFields extractedFields;
private final AtomicBoolean readyToStoreNewModel = new AtomicBoolean(true);
/**
 * Creates a persister with no model in progress: the current model id starts out as the empty string.
 *
 * @param provider        used to store model metadata and definition documents
 * @param analytics       the analytics job whose model is being stored
 * @param auditor         receives an audit message once a model has been stored
 * @param failureHandler  invoked with the exception whenever storing fails
 * @param extractedFields the fields extracted for the analytics job, used to build the model config
 */
public ChunkedTrainedModelPersister(TrainedModelProvider provider,
                                    DataFrameAnalyticsConfig analytics,
                                    DataFrameAnalyticsAuditor auditor,
                                    Consumer<Exception> failureHandler,
                                    ExtractedFields extractedFields) {
    // collaborators
    this.provider = provider;
    this.auditor = auditor;
    this.failureHandler = failureHandler;
    // job description
    this.analytics = analytics;
    this.extractedFields = extractedFields;
    // no model is being stored yet
    this.currentModelId = new AtomicReference<>("");
}
/**
 * Stores one chunk of the compressed model definition for the model currently being persisted and
 * waits up to {@code STORE_TIMEOUT_SEC} seconds for the store to complete.
 * <p>
 * If no model id is set (the model metadata has not been handled yet), the failure handler is invoked
 * and nothing is stored. A timeout is only logged; if the timed-out chunk was the end-of-stream chunk,
 * the persister is marked ready for a new model. An interrupt restores the thread's interrupt flag,
 * marks the persister ready for a new model and is reported to the failure handler.
 *
 * @param trainedModelDefinitionChunk the definition chunk to store
 */
public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
    // Read the id exactly once so the emptiness check and the document creation use the same value,
    // even if a store callback on another thread resets the id in between.
    final String modelId = this.currentModelId.get();
    if (Strings.isNullOrEmpty(modelId)) {
        failureHandler.accept(ExceptionsHelper.serverError(
            "chunked inference model definition is attempting to be stored before trained model configuration"
        ));
        return;
    }
    TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc(modelId);
    CountDownLatch latch = storeTrainedModelDoc(trainedModelDefinitionDoc);
    try {
        if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
            // Report the configured timeout rather than a hard-coded number.
            LOGGER.error("[{}] Timed out ({}s) waiting for chunked inference definition to be stored",
                analytics.getId(),
                STORE_TIMEOUT_SEC);
            if (trainedModelDefinitionChunk.isEos()) {
                this.readyToStoreNewModel.set(true);
            }
        }
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        this.readyToStoreNewModel.set(true);
        failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for chunked inference definition to be stored"));
    }
}
/**
 * Starts persisting a new model: claims the "ready to store" flag, builds the trained model config
 * from {@code inferenceModelSize}, stores it and waits up to {@code STORE_TIMEOUT_SEC} seconds.
 * <p>
 * If a previous model is still being stored, the failure handler is invoked and nothing is stored.
 * A timeout is only logged. An interrupt restores the thread's interrupt flag, releases the flag again
 * and is reported to the failure handler.
 *
 * @param inferenceModelSize size information for the model about to be stored
 */
public void createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) {
    final boolean claimed = readyToStoreNewModel.compareAndSet(true, false);
    if (claimed == false) {
        failureHandler.accept(ExceptionsHelper.serverError(
            "new inference model is attempting to be stored before completion previous model storage"
        ));
        return;
    }
    final CountDownLatch storedLatch = storeTrainedModelMetadata(createTrainedModelConfig(inferenceModelSize));
    final boolean finishedInTime;
    try {
        finishedInTime = storedLatch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        this.readyToStoreNewModel.set(true);
        failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored"));
        return;
    }
    if (finishedInTime == false) {
        LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId());
    }
}
/**
 * Stores a single definition document and returns a latch that is released when handling of the
 * document has finished.
 * <p>
 * For a non-final chunk the latch is released as soon as the document is stored. For the
 * end-of-stream chunk the completed model is logged and audited, the current model id is cleared, the
 * persister is marked ready for a new model, and the inference index is refreshed; the latch is
 * released when the refresh responds or fails. If storing fails, the persister is marked ready for a
 * new model, the failure handler is invoked and the latch is released.
 *
 * @param trainedModelDefinitionDoc the definition document to store
 * @return latch released once the store (and, for the final chunk, the refresh) has completed
 */
private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc) {
    CountDownLatch latch = new CountDownLatch(1);

    // Latch is attached to this action as it is the last one to execute.
    ActionListener<RefreshResponse> refreshListener = new LatchedActionListener<>(ActionListener.wrap(
        refreshed -> {
            if (refreshed != null) {
                LOGGER.debug(() -> new ParameterizedMessage(
                    "[{}] refreshed inference index after model store",
                    analytics.getId()
                ));
            }
        },
        e -> LOGGER.warn(
            new ParameterizedMessage("[{}] failed to refresh inference index after model store", analytics.getId()),
            e)
    ), latch);

    // First, store the model and refresh if necessary
    ActionListener<Void> storeListener = ActionListener.wrap(
        r -> {
            LOGGER.debug(() -> new ParameterizedMessage(
                "[{}] stored trained model definition chunk [{}] [{}]",
                analytics.getId(),
                trainedModelDefinitionDoc.getModelId(),
                trainedModelDefinitionDoc.getDocNum()));
            if (trainedModelDefinitionDoc.isEos() == false) {
                // Not the last chunk: no refresh needed, just release the latch.
                refreshListener.onResponse(null);
                return;
            }
            LOGGER.info(
                "[{}] finished storing trained model with id [{}]",
                analytics.getId(),
                this.currentModelId.get());
            auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]");
            this.currentModelId.set("");
            readyToStoreNewModel.set(true);
            provider.refreshInferenceIndex(refreshListener);
        },
        e -> {
            this.readyToStoreNewModel.set(true);
            // Arguments follow the placeholders: the chunk number first, then the model id.
            failureHandler.accept(ExceptionsHelper.serverError(
                "error storing trained model definition chunk [{}] with id [{}]",
                e,
                trainedModelDefinitionDoc.getDocNum(),
                trainedModelDefinitionDoc.getModelId()));
            // Release the latch so the caller does not wait for the full timeout.
            refreshListener.onResponse(null);
        }
    );
    provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, storeListener);
    return latch;
}
/**
 * Asks the provider to store the model metadata and returns a latch that is released once the store
 * call has responded. A {@code false} response or an exception marks the persister ready for a new
 * model and is reported to the failure handler.
 *
 * @param trainedModelConfig the config to store
 * @return latch released when the store call has responded
 */
private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig) {
    final CountDownLatch storedLatch = new CountDownLatch(1);
    final ActionListener<Boolean> onStored = ActionListener.wrap(
        acknowledged -> {
            if (acknowledged) {
                LOGGER.debug("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
                return;
            }
            LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId());
            readyToStoreNewModel.set(true);
            failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false"));
        },
        failure -> {
            readyToStoreNewModel.set(true);
            failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]",
                failure,
                trainedModelConfig.getModelId()));
        }
    );
    provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(onStored, storedLatch));
    return storedLatch;
}
/**
 * Assembles the {@link TrainedModelConfig} (without a definition) for a new model and records its id
 * as the current model id.
 * The model id is the analytics id followed by the creation time in epoch millis. The input field
 * names are all extracted field names except the dependent variable, and the default field map maps
 * the parent field of each non-dependent {@link MultiField} to that field's name. Heap memory and
 * operation estimates are taken from {@code modelSize}.
 */
private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) {
    final Instant now = Instant.now();
    final String newModelId = analytics.getId() + "-" + now.toEpochMilli();
    // Side effect: subsequent definition chunks are stored under this id.
    currentModelId.set(newModelId);

    final List<ExtractedField> allFields = extractedFields.getAllFields();
    final String target = getDependentVariable();
    final List<String> inputFieldNames = allFields.stream()
        .map(ExtractedField::getName)
        .filter(name -> name.equals(target) == false)
        .collect(Collectors.toList());
    final Map<String, String> multiFieldDefaults = allFields.stream()
        .filter(field -> field instanceof MultiField)
        .filter(field -> field.getName().equals(target) == false)
        .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
    final Map<String, Object> metadata = Collections.singletonMap("analytics_config",
        XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true));

    return TrainedModelConfig.builder()
        .setModelId(newModelId)
        .setCreatedBy(XPackUser.NAME)
        .setVersion(Version.CURRENT)
        .setCreateTime(now)
        // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
        .setTags(Collections.singletonList(analytics.getId()))
        .setDescription(analytics.getDescription())
        .setMetadata(metadata)
        .setEstimatedHeapMemory(modelSize.ramBytesUsed())
        .setEstimatedOperations(modelSize.numOperations())
        .setInput(new TrainedModelInput(inputFieldNames))
        .setLicenseLevel(License.OperationMode.PLATINUM.description())
        .setDefaultFieldMap(multiFieldDefaults)
        .setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields)))
        .build();
}
/**
 * Looks up the dependent variable of the configured analysis.
 *
 * @return the dependent variable for a {@link Classification} or {@link Regression} analysis,
 *         or {@code null} for any other analysis type
 */
private String getDependentVariable() {
    final String result;
    if (analytics.getAnalysis() instanceof Classification) {
        result = ((Classification) analytics.getAnalysis()).getDependentVariable();
    } else if (analytics.getAnalysis() instanceof Regression) {
        result = ((Regression) analytics.getAnalysis()).getDependentVariable();
    } else {
        result = null;
    }
    return result;
}
}

View File

@ -8,31 +8,27 @@ package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import java.io.IOException;
import java.util.Collections;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
public class AnalyticsResult implements ToXContentObject {
public static final ParseField TYPE = new ParseField("analytics_result");
private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info");
private static final ParseField COMPRESSED_INFERENCE_MODEL = new ParseField("compressed_inference_model");
private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
@ -42,53 +38,50 @@ public class AnalyticsResult implements ToXContentObject {
a -> new AnalyticsResult(
(RowResults) a[0],
(PhaseProgress) a[1],
(TrainedModelDefinition.Builder) a[2],
(MemoryUsage) a[3],
(OutlierDetectionStats) a[4],
(ClassificationStats) a[5],
(RegressionStats) a[6],
(ModelSizeInfo) a[7]
(MemoryUsage) a[2],
(OutlierDetectionStats) a[3],
(ClassificationStats) a[4],
(RegressionStats) a[5],
(ModelSizeInfo) a[6],
(TrainedModelDefinitionChunk) a[7]
));
static {
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS);
// TODO change back to STRICT_PARSER once native side is aligned
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE);
PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO);
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinitionChunk.PARSER, COMPRESSED_INFERENCE_MODEL);
}
private final RowResults rowResults;
private final PhaseProgress phaseProgress;
private final TrainedModelDefinition.Builder inferenceModelBuilder;
private final TrainedModelDefinition inferenceModel;
private final MemoryUsage memoryUsage;
private final OutlierDetectionStats outlierDetectionStats;
private final ClassificationStats classificationStats;
private final RegressionStats regressionStats;
private final ModelSizeInfo modelSizeInfo;
private final TrainedModelDefinitionChunk trainedModelDefinitionChunk;
public AnalyticsResult(@Nullable RowResults rowResults,
@Nullable PhaseProgress phaseProgress,
@Nullable TrainedModelDefinition.Builder inferenceModelBuilder,
@Nullable MemoryUsage memoryUsage,
@Nullable OutlierDetectionStats outlierDetectionStats,
@Nullable ClassificationStats classificationStats,
@Nullable RegressionStats regressionStats,
@Nullable ModelSizeInfo modelSizeInfo) {
@Nullable ModelSizeInfo modelSizeInfo,
@Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
this.rowResults = rowResults;
this.phaseProgress = phaseProgress;
this.inferenceModelBuilder = inferenceModelBuilder;
this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
this.memoryUsage = memoryUsage;
this.outlierDetectionStats = outlierDetectionStats;
this.classificationStats = classificationStats;
this.regressionStats = regressionStats;
this.modelSizeInfo = modelSizeInfo;
this.trainedModelDefinitionChunk = trainedModelDefinitionChunk;
}
public RowResults getRowResults() {
@ -99,10 +92,6 @@ public class AnalyticsResult implements ToXContentObject {
return phaseProgress;
}
public TrainedModelDefinition.Builder getInferenceModelBuilder() {
return inferenceModelBuilder;
}
public MemoryUsage getMemoryUsage() {
return memoryUsage;
}
@ -123,6 +112,10 @@ public class AnalyticsResult implements ToXContentObject {
return modelSizeInfo;
}
/**
 * @return the compressed model definition chunk held by this result; may be {@code null},
 *         as the corresponding constructor argument is declared nullable
 */
public TrainedModelDefinitionChunk getTrainedModelDefinitionChunk() {
return trainedModelDefinitionChunk;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
@ -132,11 +125,6 @@ public class AnalyticsResult implements ToXContentObject {
if (phaseProgress != null) {
builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress);
}
if (inferenceModel != null) {
builder.field(INFERENCE_MODEL.getPreferredName(),
inferenceModel,
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")));
}
if (memoryUsage != null) {
builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params);
}
@ -152,6 +140,9 @@ public class AnalyticsResult implements ToXContentObject {
if (modelSizeInfo != null) {
builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo);
}
if (trainedModelDefinitionChunk != null) {
builder.field(COMPRESSED_INFERENCE_MODEL.getPreferredName(), trainedModelDefinitionChunk);
}
builder.endObject();
return builder;
}
@ -168,17 +159,17 @@ public class AnalyticsResult implements ToXContentObject {
AnalyticsResult that = (AnalyticsResult) other;
return Objects.equals(rowResults, that.rowResults)
&& Objects.equals(phaseProgress, that.phaseProgress)
&& Objects.equals(inferenceModel, that.inferenceModel)
&& Objects.equals(memoryUsage, that.memoryUsage)
&& Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
&& Objects.equals(classificationStats, that.classificationStats)
&& Objects.equals(modelSizeInfo, that.modelSizeInfo)
&& Objects.equals(trainedModelDefinitionChunk, that.trainedModelDefinitionChunk)
&& Objects.equals(regressionStats, that.regressionStats);
}
@Override
public int hashCode() {
return Objects.hash(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
regressionStats);
return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats,
regressionStats, modelSizeInfo, trainedModelDefinitionChunk);
}
}

View File

@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
/**
 * A single piece of a trained model definition string, parsed from the {@code definition},
 * {@code doc_num} and (optional) {@code eos} fields.
 */
public class TrainedModelDefinitionChunk implements ToXContentObject {

    private static final ParseField DEFINITION = new ParseField("definition");
    private static final ParseField DOC_NUM = new ParseField("doc_num");
    private static final ParseField EOS = new ParseField("eos");

    // Constructor arguments are positional: the declaration order below fixes the array indices.
    public static final ConstructingObjectParser<TrainedModelDefinitionChunk, Void> PARSER = new ConstructingObjectParser<>(
        "chunked_trained_model_definition",
        args -> new TrainedModelDefinitionChunk((String) args[0], (Integer) args[1], (Boolean) args[2]));

    static {
        PARSER.declareString(constructorArg(), DEFINITION);
        PARSER.declareInt(constructorArg(), DOC_NUM);
        PARSER.declareBoolean(optionalConstructorArg(), EOS);
    }

    private final String definition;
    private final int docNum;
    // Kept boxed so that an absent "eos" field can be told apart from an explicit false.
    private final Boolean eos;

    /**
     * @param definition the chunk's definition string
     * @param docNum     the number of this chunk
     * @param eos        end-of-stream flag; may be {@code null} when the field was not supplied
     */
    public TrainedModelDefinitionChunk(String definition, int docNum, Boolean eos) {
        this.definition = definition;
        this.docNum = docNum;
        this.eos = eos;
    }

    /**
     * Builds the definition document for this chunk.
     *
     * @param modelId the id of the model the chunk belongs to
     * @return a document carrying this chunk's string, its length, its doc number and its end-of-stream flag,
     *         stamped with the current definition compression version
     */
    public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId) {
        final TrainedModelDefinitionDoc.Builder docBuilder = new TrainedModelDefinitionDoc.Builder()
            .setModelId(modelId)
            .setDocNum(docNum)
            .setCompressedString(definition)
            .setDefinitionLength(definition.length())
            .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
            .setEos(isEos());
        return docBuilder.build();
    }

    /**
     * @return {@code true} only when the end-of-stream flag was supplied and is true
     */
    public boolean isEos() {
        return Boolean.TRUE.equals(eos);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(DEFINITION.getPreferredName(), definition);
        builder.field(DOC_NUM.getPreferredName(), docNum);
        // An unset flag is left out entirely rather than written as a value.
        if (eos != null) {
            builder.field(EOS.getPreferredName(), eos);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        final TrainedModelDefinitionChunk other = (TrainedModelDefinitionChunk) o;
        if (docNum != other.docNum) {
            return false;
        }
        return Objects.equals(definition, other.definition) && Objects.equals(eos, other.eos);
    }

    @Override
    public int hashCode() {
        return Objects.hash(definition, docNum, eos);
    }
}

View File

@ -33,6 +33,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version");
public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length");
public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length");
public static final ParseField EOS = new ParseField("eos");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelDefinitionDoc.Builder, Void> LENIENT_PARSER = createParser(true);
@ -48,6 +49,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION);
parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH);
parser.declareLong(TrainedModelDefinitionDoc.Builder::setTotalDefinitionLength, TOTAL_DEFINITION_LENGTH);
parser.declareBoolean(TrainedModelDefinitionDoc.Builder::setEos, EOS);
return parser;
}
@ -63,23 +65,26 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
private final String compressedString;
private final String modelId;
private final int docNum;
private final long totalDefinitionLength;
// for BWC
private final Long totalDefinitionLength;
private final long definitionLength;
private final int compressionVersion;
private final boolean eos;
private TrainedModelDefinitionDoc(String compressedString,
String modelId,
int docNum,
long totalDefinitionLength,
Long totalDefinitionLength,
long definitionLength,
int compressionVersion) {
int compressionVersion,
boolean eos) {
this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION);
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
if (docNum < 0) {
throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0");
}
this.docNum = docNum;
if (totalDefinitionLength <= 0L) {
if (totalDefinitionLength != null && totalDefinitionLength <= 0L) {
throw new IllegalArgumentException("[total_definition_length] must be greater than 0");
}
this.totalDefinitionLength = totalDefinitionLength;
@ -88,6 +93,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
}
this.definitionLength = definitionLength;
this.compressionVersion = compressionVersion;
this.eos = eos;
}
public String getCompressedString() {
@ -102,7 +108,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
return docNum;
}
public long getTotalDefinitionLength() {
public Long getTotalDefinitionLength() {
return totalDefinitionLength;
}
@ -114,16 +120,24 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
return compressionVersion;
}
/**
 * @return the stored end-of-stream flag of this definition document
 */
public boolean isEos() {
return eos;
}
/**
 * @return the id for this document, as produced by {@code docId} from this document's
 *         model id and doc number
 */
public String getDocId() {
return docId(modelId, docNum);
}
/**
 * Writes this definition document as an object containing the doc type, model id, doc number,
 * definition length, compression version, the definition string and the end-of-stream flag.
 * The total definition length is written only when it is set.
 *
 * @param builder the builder to write to
 * @param params  serialization params (not consulted here)
 * @return the same builder, for chaining
 * @throws IOException if writing to the builder fails
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject();
    builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
    builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
    builder.field(DOC_NUM.getPreferredName(), docNum);
    // totalDefinitionLength is nullable (retained for BWC only). Writing it unconditionally would emit an
    // explicit null field, which the long-typed parser declaration for this field is presumably unable to
    // read back, so the field is simply left out when there is no value.
    if (totalDefinitionLength != null) {
        builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength);
    }
    builder.field(DEFINITION_LENGTH.getPreferredName(), definitionLength);
    builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion);
    builder.field(DEFINITION.getPreferredName(), compressedString);
    builder.field(EOS.getPreferredName(), eos);
    builder.endObject();
    return builder;
}
@ -143,12 +157,13 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
Objects.equals(definitionLength, that.definitionLength) &&
Objects.equals(totalDefinitionLength, that.totalDefinitionLength) &&
Objects.equals(compressionVersion, that.compressionVersion) &&
Objects.equals(eos, that.eos) &&
Objects.equals(compressedString, that.compressedString);
}
@Override
public int hashCode() {
return Objects.hash(modelId, docNum, totalDefinitionLength, definitionLength, compressionVersion, compressedString);
return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos);
}
public static class Builder {
@ -156,9 +171,10 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
private String modelId;
private String compressedString;
private int docNum;
private long totalDefinitionLength;
private Long totalDefinitionLength;
private long definitionLength;
private int compressionVersion;
private boolean eos;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -190,6 +206,11 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
return this;
}
/**
 * Sets the end-of-stream flag for the document being built.
 *
 * @param eos the flag value to store
 * @return this builder, for chaining
 */
public Builder setEos(boolean eos) {
this.eos = eos;
return this;
}
public TrainedModelDefinitionDoc build() {
return new TrainedModelDefinitionDoc(
this.compressedString,
@ -197,7 +218,8 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
this.docNum,
this.totalDefinitionLength,
this.definitionLength,
this.compressionVersion);
this.compressionVersion,
this.eos);
}
}

View File

@ -14,10 +14,14 @@ import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.MultiSearchAction;
import org.elasticsearch.action.search.MultiSearchRequest;
@ -143,6 +147,74 @@ public class TrainedModelProvider {
storeTrainedModelAndDefinition(trainedModelConfig, listener);
}
/**
 * Indexes the given trained model config (without a model definition) into the latest inference index.
 *
 * @param trainedModelConfig the config to store; its model definition is expected to be {@code null}
 * @param listener           notified with {@code true} on success. Fails with a
 *                           {@link ResourceAlreadyExistsException} when the model id belongs to a model stored
 *                           as a resource or the index request hits a version conflict, and with an
 *                           {@link ElasticsearchStatusException} (internal server error) for any other failure
 */
public void storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig,
                                      ActionListener<Boolean> listener) {
    final String modelId = trainedModelConfig.getModelId();
    if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
        listener.onFailure(new ResourceAlreadyExistsException(
            Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
        return;
    }
    assert trainedModelConfig.getModelDefinition() == null;

    executeAsyncWithOrigin(client,
        ML_ORIGIN,
        IndexAction.INSTANCE,
        createRequest(modelId, InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelConfig),
        ActionListener.wrap(
            indexResponse -> listener.onResponse(true),
            e -> {
                if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
                    listener.onFailure(new ResourceAlreadyExistsException(
                        Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
                } else {
                    // The Messages templates use {0}-style placeholders, so the text is expanded with
                    // Messages.getMessage (as everywhere else in this method) instead of handing the raw
                    // template and its argument to the exception constructor.
                    listener.onFailure(
                        new ElasticsearchStatusException(
                            Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL, modelId),
                            RestStatus.INTERNAL_SERVER_ERROR,
                            e));
                }
            }
        ));
}
/**
 * Indexes a single model definition document into the latest inference index.
 *
 * @param trainedModelDefinitionDoc the definition document to store
 * @param listener                  notified with {@code null} on success. Fails with a
 *                                  {@link ResourceAlreadyExistsException} when the model id belongs to a model
 *                                  stored as a resource or the index request hits a version conflict, and with
 *                                  an {@link ElasticsearchStatusException} (internal server error) otherwise
 */
public void storeTrainedModelDefinitionDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc, ActionListener<Void> listener) {
    final String modelId = trainedModelDefinitionDoc.getModelId();
    if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
        listener.onFailure(new ResourceAlreadyExistsException(
            Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, modelId)));
        return;
    }

    executeAsyncWithOrigin(client,
        ML_ORIGIN,
        IndexAction.INSTANCE,
        createRequest(trainedModelDefinitionDoc.getDocId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelDefinitionDoc),
        ActionListener.wrap(
            indexResponse -> listener.onResponse(null),
            e -> {
                if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
                    listener.onFailure(new ResourceAlreadyExistsException(
                        Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_DOC_EXISTS,
                            modelId,
                            trainedModelDefinitionDoc.getDocNum())));
                } else {
                    // The Messages templates use {0}-style placeholders, so the text is expanded with
                    // Messages.getMessage (as everywhere else in this method) instead of handing the raw
                    // template and its argument to the exception constructor.
                    listener.onFailure(
                        new ElasticsearchStatusException(
                            Messages.getMessage(Messages.INFERENCE_FAILED_TO_STORE_MODEL, modelId),
                            RestStatus.INTERNAL_SERVER_ERROR,
                            e));
                }
            }
        ));
}
/**
 * Issues a refresh request against {@link InferenceIndexConstants#INDEX_PATTERN}, executed with the ML origin.
 *
 * @param listener receives the refresh response or the failure
 */
public void refreshInferenceIndex(ActionListener<RefreshResponse> listener) {
    final RefreshRequest refreshRequest = new RefreshRequest(InferenceIndexConstants.INDEX_PATTERN);
    executeAsyncWithOrigin(client, ML_ORIGIN, RefreshAction.INSTANCE, refreshRequest, listener);
}
private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig,
ActionListener<Boolean> listener) {
@ -165,7 +237,8 @@ public class TrainedModelProvider {
.setCompressedString(chunkedStrings.get(i))
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
.setDefinitionLength(chunkedStrings.get(i).length())
.setTotalDefinitionLength(compressedString.length())
// If it is the last doc, it is the EOS
.setEos(i == chunkedStrings.size() - 1)
.build());
}
} catch (IOException ex) {
@ -265,6 +338,9 @@ public class TrainedModelProvider {
.unmappedType("long"))
.request();
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
// TODO how could we stream in the model definition WHILE parsing it?
// This would reduce the overall memory usage as we won't have to load the whole compressed string
// XContentParser supports streams.
searchResponse -> {
if (searchResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(
@ -274,19 +350,16 @@ public class TrainedModelProvider {
List<TrainedModelDefinitionDoc> docs = handleHits(searchResponse.getHits().getHits(),
modelId,
this::parseModelDefinitionDocLenientlyFromSource);
String compressedString = docs.stream()
.map(TrainedModelDefinitionDoc::getCompressedString)
.collect(Collectors.joining());
if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
listener.onFailure(ExceptionsHelper.serverError(
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
return;
try {
String compressedString = getDefinitionFromDocs(docs, modelId);
InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
compressedString,
InferenceDefinition::fromXContent,
xContentRegistry);
listener.onResponse(inferenceDefinition);
} catch (ElasticsearchException elasticsearchException) {
listener.onFailure(elasticsearchException);
}
InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
compressedString,
InferenceDefinition::fromXContent,
xContentRegistry);
listener.onResponse(inferenceDefinition);
},
e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
@ -361,15 +434,14 @@ public class TrainedModelProvider {
List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
modelId,
this::parseModelDefinitionDocLenientlyFromSource);
String compressedString = docs.stream()
.map(TrainedModelDefinitionDoc::getCompressedString)
.collect(Collectors.joining());
if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
listener.onFailure(ExceptionsHelper.serverError(
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
try {
String compressedString = getDefinitionFromDocs(docs, modelId);
builder.setDefinitionFromString(compressedString);
} catch (ElasticsearchException elasticsearchException) {
listener.onFailure(elasticsearchException);
return;
}
builder.setDefinitionFromString(compressedString);
} catch (ResourceNotFoundException ex) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
@ -806,6 +878,26 @@ public class TrainedModelProvider {
return results;
}
/**
 * Concatenates the compressed strings of the given definition docs, in list order, and checks that the
 * result is complete.
 * <p>
 * When the first doc carries a total definition length (BWC path), the concatenated length must equal it.
 * Otherwise the last doc must be flagged as end-of-stream and its doc number must equal the list's last index.
 *
 * @param docs    the definition docs to join
 * @param modelId the model id, used only in the error message
 * @return the concatenated compressed definition
 * @throws ElasticsearchException if the completeness check fails
 */
private static String getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId) throws ElasticsearchException {
    final StringBuilder joined = new StringBuilder();
    for (TrainedModelDefinitionDoc doc : docs) {
        joined.append(doc.getCompressedString());
    }
    final String compressedString = joined.toString();

    final boolean complete;
    // BWC for when we tracked the total definition length
    // TODO: remove in 9
    final Long totalDefinitionLength = docs.get(0).getTotalDefinitionLength();
    if (totalDefinitionLength != null) {
        complete = compressedString.length() == totalDefinitionLength;
    } else {
        final int lastIndex = docs.size() - 1;
        final TrainedModelDefinitionDoc lastDoc = docs.get(lastIndex);
        // A missing final doc shows up as a last doc without the EOS flag; a missing earlier doc
        // shows up as a doc number that does not line up with the list position.
        complete = lastDoc.isEos() && lastDoc.getDocNum() == lastIndex;
    }

    if (complete == false) {
        throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId));
    }
    return compressedString;
}
static List<String> chunkStringWithSize(String str, int chunkSize) {
List<String> subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize));
for (int i = 0; i < str.length();i += chunkSize) {
@ -836,14 +928,18 @@ public class TrainedModelProvider {
}
}
/**
 * Creates an index request targeting the given index by delegating to the overload that takes an
 * {@link IndexRequest}.
 *
 * @param docId the id for the document
 * @param index the index the new request is constructed with
 * @param body  the object that becomes the request source
 */
private IndexRequest createRequest(String docId, String index, ToXContentObject body) {
return createRequest(new IndexRequest(index), docId, body);
}
/**
 * Creates an index request from a no-argument {@link IndexRequest} (no index supplied here) by delegating
 * to the overload that takes an {@link IndexRequest}.
 *
 * @param docId the id for the document
 * @param body  the object that becomes the request source
 */
private IndexRequest createRequest(String docId, ToXContentObject body) {
return createRequest(new IndexRequest(), docId, body);
}
private IndexRequest createRequest(IndexRequest request, String docId, ToXContentObject body) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
return new IndexRequest()
.opType(DocWriteRequest.OpType.CREATE)
.id(docId)
.source(source);
return request.opType(DocWriteRequest.OpType.CREATE).id(docId).source(source);
} catch (IOException ex) {
// This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again
// that is not the users fault. We did something wrong and should throw.

View File

@ -5,32 +5,20 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
@ -38,20 +26,14 @@ import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mockito;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -156,90 +138,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
}
// Checks that when the process emits a result carrying an inference model and the provider reports a
// successful store, the captured TrainedModelConfig has the expected contents and exactly one info
// audit message is written.
@SuppressWarnings("unchecked")
public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
givenDataFrameRows(0);
// Stub the provider so the store call answers its listener with success.
doAnswer(invocationOnMock -> {
ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
storeListener.onResponse(true);
return null;
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
// Three extracted fields: two plain doc-value fields and one multi-field ("bar" -> "bar.keyword").
List<ExtractedField> extractedFieldList = new ArrayList<>(3);
extractedFieldList.add(new DocValueField("foo", Collections.emptySet()));
extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet())));
extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
resultProcessor.process(process);
resultProcessor.awaitForCompletion();
// Capture the config that was handed to the provider and inspect it field by field.
ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));
TrainedModelConfig storedModel = storedModelCaptor.getValue();
assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
assertThat(storedModel.getModelId(), containsString(JOB_ID));
assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME));
assertThat(storedModel.getTags(), contains(JOB_ID));
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
assertThat(storedModel.getDefaultFieldMap(), equalTo(Collections.singletonMap("bar", "bar.keyword")));
// "foo" is expected to be absent from the input field names here.
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz")));
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
if (targetType.equals(TargetType.CLASSIFICATION)) {
assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification"));
} else {
assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression"));
}
// The only metadata entry must be the analytics config rendered as a map.
Map<String, Object> metadata = storedModel.getMetadata();
assertThat(metadata.size(), equalTo(1));
assertThat(metadata, hasKey("analytics_config"));
Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
true);
assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));
ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
Mockito.verifyNoMoreInteractions(auditor);
}
// Checks that a failure reported by the provider while storing the model is audited as an error,
// recorded as the processor's failure, and leaves the writing-results progress at 0.
@SuppressWarnings("unchecked")
public void testProcess_GivenInferenceModelFailedToStore() {
givenDataFrameRows(0);
// Stub the provider so the store call answers its listener with a failure.
doAnswer(invocationOnMock -> {
ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
storeListener.onFailure(new RuntimeException("some failure"));
return null;
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();
resultProcessor.process(process);
resultProcessor.awaitForCompletion();
// This test verifies the processor knows how to handle a failure on storing the model and completes normally
ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID));
Mockito.verifyNoMoreInteractions(auditor);
assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
}
// Stubs the mocked process so that reading analytics results yields an iterator over the given list.
private void givenProcessResults(List<AnalyticsResult> results) {
when(process.readAnalyticsResults()).thenReturn(results.iterator());
}
@ -256,7 +154,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
}
private AnalyticsResultProcessor createResultProcessor(List<ExtractedField> fieldNames) {
return new AnalyticsResultProcessor(analyticsConfig,
dataFrameRowsJoiner,
statsHolder,

View File

@ -0,0 +1,150 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
 * Tests for {@link ChunkedTrainedModelPersister}, exercised against a mocked {@link TrainedModelProvider}
 * and a mocked {@link DataFrameAnalyticsAuditor}.
 */
public class ChunkedTrainedModelPersisterTests extends ESTestCase {

    private static final String JOB_ID = "analytics-result-processor-tests";
    private static final String JOB_DESCRIPTION = "This describes the job of these tests";

    private TrainedModelProvider trainedModelProvider;
    private DataFrameAnalyticsAuditor auditor;

    /** Replaces the collaborators with fresh mocks before each test. */
    @Before
    public void setUpMocks() {
        auditor = mock(DataFrameAnalyticsAuditor.class);
        trainedModelProvider = mock(TrainedModelProvider.class);
    }

    /**
     * Feeds the persister one model size info followed by two definition chunks and checks that:
     * the model metadata is stored once, two definition docs are stored with doc numbers 0 and 1,
     * all of them share one model id, and a single info message is audited.
     */
    @SuppressWarnings("unchecked")
    public void testPersistAllDocs() {
        DataFrameAnalyticsConfig analyticsConfig = createAnalyticsConfig();
        List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));
        givenProviderStoresSucceed();

        ChunkedTrainedModelPersister persister = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig);
        ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
        TrainedModelDefinitionChunk chunkZero = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 0, false);
        TrainedModelDefinitionChunk chunkOne = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true);

        persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
        persister.createAndIndexInferenceModelDoc(chunkZero);
        persister.createAndIndexInferenceModelDoc(chunkOne);

        ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
        verify(trainedModelProvider).storeTrainedModelMetadata(storedModelCaptor.capture(), any(ActionListener.class));
        ArgumentCaptor<TrainedModelDefinitionDoc> storedDocCapture = ArgumentCaptor.forClass(TrainedModelDefinitionDoc.class);
        verify(trainedModelProvider, times(2))
            .storeTrainedModelDefinitionDoc(storedDocCapture.capture(), any(ActionListener.class));

        TrainedModelConfig storedModel = storedModelCaptor.getValue();
        assertStoredModelConfig(storedModel, analyticsConfig, modelSizeInfo);

        // Definition docs are captured in call order; their doc numbers must be 0 and 1.
        List<TrainedModelDefinitionDoc> storedDocs = storedDocCapture.getAllValues();
        for (int docNum = 0; docNum < 2; docNum++) {
            assertThat(storedDocs.get(docNum).getDocNum(), equalTo(docNum));
        }
        // Every definition doc must carry the id of the stored model config.
        for (int docNum = 0; docNum < 2; docNum++) {
            assertThat(storedModel.getModelId(), equalTo(storedDocs.get(docNum).getModelId()));
        }

        ArgumentCaptor<String> auditMessage = ArgumentCaptor.forClass(String.class);
        verify(auditor).info(eq(JOB_ID), auditMessage.capture());
        assertThat(auditMessage.getValue(), containsString("Stored trained model with id [" + JOB_ID));
        Mockito.verifyNoMoreInteractions(auditor);
    }

    /** Builds an analytics config for {@link #JOB_ID} whose analysis is randomly regression or classification. */
    private DataFrameAnalyticsConfig createAnalyticsConfig() {
        DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
        configBuilder.setId(JOB_ID);
        configBuilder.setDescription(JOB_DESCRIPTION);
        configBuilder.setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null));
        configBuilder.setDest(new DataFrameAnalyticsDest("my_dest", null));
        configBuilder.setAnalysis(randomBoolean() ? new Regression("foo") : new Classification("foo"));
        return configBuilder.build();
    }

    /** Makes the mocked provider acknowledge every metadata store and every definition doc store. */
    @SuppressWarnings("unchecked")
    private void givenProviderStoresSucceed() {
        doAnswer(invocation -> {
            ActionListener<Boolean> metadataListener = (ActionListener<Boolean>) invocation.getArguments()[1];
            metadataListener.onResponse(true);
            return null;
        }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class));

        doAnswer(invocation -> {
            ActionListener<Void> docListener = (ActionListener<Void>) invocation.getArguments()[1];
            docListener.onResponse(null);
            return null;
        }).when(trainedModelProvider).storeTrainedModelDefinitionDoc(any(TrainedModelDefinitionDoc.class), any(ActionListener.class));
    }

    /**
     * Asserts the fields of the stored model config against the analytics config and the model size info
     * it was created from.
     */
    private void assertStoredModelConfig(TrainedModelConfig storedModel,
                                         DataFrameAnalyticsConfig analyticsConfig,
                                         ModelSizeInfo modelSizeInfo) {
        assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
        assertThat(storedModel.getModelId(), containsString(JOB_ID));
        assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
        assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME));
        assertThat(storedModel.getTags(), contains(JOB_ID));
        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
        // The config is stored without a definition; the definition arrives separately as docs.
        assertThat(storedModel.getModelDefinition(), is(nullValue()));
        assertThat(storedModel.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
        assertThat(storedModel.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));

        String expectedInferenceName = analyticsConfig.getAnalysis() instanceof Classification ? "classification" : "regression";
        assertThat(storedModel.getInferenceConfig().getName(), equalTo(expectedInferenceName));

        // The only metadata entry is the analytics config, stored in its map form.
        Map<String, Object> metadata = storedModel.getMetadata();
        assertThat(metadata.size(), equalTo(1));
        assertThat(metadata, hasKey("analytics_config"));
        Map<String, Object> analyticsConfigAsMap =
            XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(), true);
        assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));
    }

    /**
     * Creates the persister under test, wired to the mocked provider and auditor and a no-op callback.
     *
     * @param fieldNames      the extracted fields handed to the persister
     * @param analyticsConfig the analytics config the persister works for
     */
    private ChunkedTrainedModelPersister createChunkedTrainedModelPersister(List<ExtractedField> fieldNames,
                                                                            DataFrameAnalyticsConfig analyticsConfig) {
        ExtractedFields extractedFields = new ExtractedFields(fieldNames, Collections.emptyMap());
        return new ChunkedTrainedModelPersister(trainedModelProvider, analyticsConfig, auditor, unused -> { }, extractedFields);
    }
}

View File

@ -20,8 +20,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
@ -46,21 +44,18 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
protected AnalyticsResult createTestInstance() {
RowResults rowResults = null;
PhaseProgress phaseProgress = null;
TrainedModelDefinition.Builder inferenceModel = null;
MemoryUsage memoryUsage = null;
OutlierDetectionStats outlierDetectionStats = null;
ClassificationStats classificationStats = null;
RegressionStats regressionStats = null;
ModelSizeInfo modelSizeInfo = null;
TrainedModelDefinitionChunk trainedModelDefinitionChunk = null;
if (randomBoolean()) {
rowResults = RowResultsTests.createRandom();
}
if (randomBoolean()) {
phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100));
}
if (randomBoolean()) {
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
}
if (randomBoolean()) {
memoryUsage = MemoryUsageTests.createRandom();
}
@ -76,8 +71,12 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
if (randomBoolean()) {
modelSizeInfo = ModelSizeInfoTests.createRandom();
}
return new AnalyticsResult(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats,
classificationStats, regressionStats, modelSizeInfo);
if (randomBoolean()) {
String def = randomAlphaOfLengthBetween(100, 1000);
trainedModelDefinitionChunk = new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean());
}
return new AnalyticsResult(rowResults, phaseProgress, memoryUsage, outlierDetectionStats,
classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk);
}
@Override