[ML] Handle compressed model stream from native process (#58009)

This moves model storage from handling the fully parsed JSON string to handling two separate document types:

1. ModelSizeInfo, which contains the model size information
2. TrainedModelDefinitionChunk, which contains one chunk of the compressed model definition string

`model_size_info` is assumed to be handled first; it generates the model_id and stores the initial trained model config object. Each chunk is then assumed to arrive in the correct order, so that concatenating the chunks yields the compressed definition.

Native side change: https://github.com/elastic/ml-cpp/pull/1349
Parent: 9c77862a23
Commit: c64e283dbf
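A minimal sketch of the intended call order, mirroring the new ChunkedTrainedModelPersisterIT test below (the persister construction is omitted, and chunkSize is an illustrative assumption taken from that test):

    // The persister must see model_size_info first: this generates the model_id
    // and stores the initial trained model config.
    persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
    // Definition chunks then arrive in order; the last one is flagged as EOS so
    // the chunks can be concatenated back into the compressed definition.
    List<String> chunks = chunkStringWithSize(compressedDefinition, chunkSize);
    for (int i = 0; i < chunks.size(); i++) {
        boolean isLast = i == chunks.size() - 1;
        persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, isLast));
    }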
@@ -89,6 +89,7 @@ public final class Messages {
        " (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";

    public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
    public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
    public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
    public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
    public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
@@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.integration;

import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionModule;

@@ -67,7 +66,6 @@ import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {

    private static final String BOOLEAN_FIELD = "boolean-field";
@@ -5,7 +5,6 @@
 */
package org.elasticsearch.xpack.ml.integration;

import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionModule;
import org.elasticsearch.action.DocWriteRequest;

@@ -45,7 +44,6 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;

@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {

    private static final String NUMERICAL_FEATURE_FIELD = "feature";
@@ -0,0 +1,130 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;

public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {

    private TrainedModelProvider trainedModelProvider;

    @Before
    public void createComponents() throws Exception {
        trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
        waitForMlTemplates();
    }

    public void testStoreModelViaChunkedPersister() throws IOException {
        String modelId = "stored-chunked-model";
        DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
            .setId(modelId)
            .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
            .setDest(new DataFrameAnalyticsDest("my_dest", null))
            .setAnalysis(new Regression("foo"))
            .build();
        List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));
        TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
        String compressedDefinition = configBuilder.build().getCompressedDefinition();
        int totalSize = compressedDefinition.length();
        List<String> chunks = chunkStringWithSize(compressedDefinition, totalSize/3);

        ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider,
            analyticsConfig,
            new DataFrameAnalyticsAuditor(client(), "test-node"),
            (ex) -> { throw new ElasticsearchException(ex); },
            new ExtractedFields(extractedFieldList, Collections.emptyMap())
        );

        //Accuracy for size is not tested here
        ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
        persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
        for (int i = 0; i < chunks.size(); i++) {
            persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1)));
        }

        PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
        trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
        Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
        assertThat(ids.v1(), equalTo(1L));

        PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
        trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);

        TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
        assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
        assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
        assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
    }

    private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
        TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder();
        long bytesUsed = definitionBuilder.build().ramBytesUsed();
        long operations = definitionBuilder.build().getTrainedModel().estimatedNumOperations();
        return TrainedModelConfig.builder()
            .setCreatedBy("ml_test")
            .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION))
            .setDescription("trained model config for test")
            .setModelId(modelId)
            .setVersion(Version.CURRENT)
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .setEstimatedHeapMemory(bytesUsed)
            .setEstimatedOperations(operations)
            .setInput(TrainedModelInputTests.createRandomInput());
    }

    public static List<String> chunkStringWithSize(String str, int chunkSize) {
        List<String> subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize);
        for (int i = 0; i < str.length(); i += chunkSize) {
            subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
        }
        return subStrings;
    }

    @Override
    public NamedXContentRegistry xContentRegistry() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
        return new NamedXContentRegistry(namedXContent);
    }

}
@@ -32,8 +32,11 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;

@@ -157,8 +160,8 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
            equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
    }

    public void testGetTruncatedModelDefinition() throws Exception {
        String modelId = "test-get-truncated-model-config";
    public void testGetTruncatedModelDeprecatedDefinition() throws Exception {
        String modelId = "test-get-truncated-legacy-model-config";
        TrainedModelConfig config = buildTrainedModelConfig(modelId);
        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

@@ -196,6 +199,51 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
        assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
    }

    public void testGetTruncatedModelDefinition() throws Exception {
        String modelId = "test-get-truncated-model-config";
        TrainedModelConfig config = buildTrainedModelConfig(modelId);
        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
        assertThat(putConfigHolder.get(), is(true));
        assertThat(exceptionHolder.get(), is(nullValue()));

        List<String> chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3);

        List<TrainedModelDefinitionDoc.Builder> docBuilders = IntStream.range(0, chunks.size())
            .mapToObj(i -> new TrainedModelDefinitionDoc.Builder()
                .setDocNum(i)
                .setCompressedString(chunks.get(i))
                .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
                .setDefinitionLength(chunks.get(i).length())
                .setEos(i == chunks.size() - 1)
                .setModelId(modelId))
            .collect(Collectors.toList());
        boolean missingEos = randomBoolean();
        docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false);
        for (int i = missingEos ? 0 : 1 ; i < docBuilders.size(); ++i) {
            TrainedModelDefinitionDoc doc = docBuilders.get(i).build();
            try(XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(),
                new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")))) {
                AtomicReference<IndexResponse> putDocHolder = new AtomicReference<>();
                blockingCall(listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
                        .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
                        .setSource(xContentBuilder)
                        .setId(TrainedModelDefinitionDoc.docId(modelId, 0))
                        .execute(listener),
                    putDocHolder,
                    exceptionHolder);
                assertThat(exceptionHolder.get(), is(nullValue()));
            }
        }
        AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
        assertThat(getConfigHolder.get(), is(nullValue()));
        assertThat(exceptionHolder.get(), is(not(nullValue())));
        assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
    }

    private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
        return TrainedModelConfig.builder()
            .setCreatedBy("ml_test")
@@ -8,48 +8,29 @@ package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

import java.time.Instant;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.toList;

public class AnalyticsResultProcessor {

@@ -70,11 +51,10 @@ public class AnalyticsResultProcessor {
    private final DataFrameAnalyticsConfig analytics;
    private final DataFrameRowsJoiner dataFrameRowsJoiner;
    private final StatsHolder statsHolder;
    private final TrainedModelProvider trainedModelProvider;
    private final DataFrameAnalyticsAuditor auditor;
    private final StatsPersister statsPersister;
    private final ExtractedFields extractedFields;
    private final CountDownLatch completionLatch = new CountDownLatch(1);
    private final ChunkedTrainedModelPersister chunkedTrainedModelPersister;
    private volatile String failure;
    private volatile boolean isCancelled;

@@ -84,10 +64,15 @@ public class AnalyticsResultProcessor {
        this.analytics = Objects.requireNonNull(analytics);
        this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
        this.statsHolder = Objects.requireNonNull(statsHolder);
        this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
        this.auditor = Objects.requireNonNull(auditor);
        this.statsPersister = Objects.requireNonNull(statsPersister);
        this.extractedFields = Objects.requireNonNull(extractedFields);
        this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister(
            trainedModelProvider,
            analytics,
            auditor,
            this::setAndReportFailure,
            extractedFields
        );
    }

    @Nullable

@@ -166,9 +151,13 @@ public class AnalyticsResultProcessor {
                phaseProgress.getProgressPercent());
            statsHolder.getProgressTracker().updatePhase(phaseProgress);
        }
        TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
        if (inferenceModelBuilder != null) {
            createAndIndexInferenceModel(inferenceModelBuilder);
        ModelSizeInfo modelSize = result.getModelSizeInfo();
        if (modelSize != null) {
            chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize);
        }
        TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk();
        if (trainedModelDefinitionChunk != null) {
            chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk);
        }
        MemoryUsage memoryUsage = result.getMemoryUsage();
        if (memoryUsage != null) {

@@ -191,82 +180,6 @@ public class AnalyticsResultProcessor {
        }
    }

    private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
        TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
        CountDownLatch latch = storeTrainedModel(trainedModelConfig);

        try {
            if (latch.await(30, TimeUnit.SECONDS) == false) {
                LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
        }
    }

    private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) {
        Instant createTime = Instant.now();
        String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
        TrainedModelDefinition definition = inferenceModel.build();
        String dependentVariable = getDependentVariable();
        List<ExtractedField> fieldNames = extractedFields.getAllFields();
        List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
            .map(ExtractedField::getName)
            .filter(f -> f.equals(dependentVariable) == false)
            .collect(toList());
        Map<String, String> defaultFieldMapping = fieldNames.stream()
            .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
            .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
        return TrainedModelConfig.builder()
            .setModelId(modelId)
            .setCreatedBy(XPackUser.NAME)
            .setVersion(Version.CURRENT)
            .setCreateTime(createTime)
            // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
            .setTags(Collections.singletonList(analytics.getId()))
            .setDescription(analytics.getDescription())
            .setMetadata(Collections.singletonMap("analytics_config",
                XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
            .setEstimatedHeapMemory(definition.ramBytesUsed())
            .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
            .setParsedDefinition(inferenceModel)
            .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .setDefaultFieldMap(defaultFieldMapping)
            .setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields)))
            .build();
    }

    private String getDependentVariable() {
        if (analytics.getAnalysis() instanceof Classification) {
            return ((Classification)analytics.getAnalysis()).getDependentVariable();
        }
        if (analytics.getAnalysis() instanceof Regression) {
            return ((Regression)analytics.getAnalysis()).getDependentVariable();
        }
        return null;
    }

    private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
        CountDownLatch latch = new CountDownLatch(1);
        ActionListener<Boolean> storeListener = ActionListener.wrap(
            aBoolean -> {
                if (aBoolean == false) {
                    LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
                    setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
                } else {
                    LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
                    auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
                }
            },
            e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e,
                trainedModelConfig.getModelId()))
        );
        trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
        return latch;
    }

    private void setAndReportFailure(Exception e) {
        LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
        failure = "error processing results; " + e.getMessage();
@@ -0,0 +1,235 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.dataframe.process;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.toList;

public class ChunkedTrainedModelPersister {

    private static final Logger LOGGER = LogManager.getLogger(ChunkedTrainedModelPersister.class);
    private static final int STORE_TIMEOUT_SEC = 30;
    private final TrainedModelProvider provider;
    private final AtomicReference<String> currentModelId;
    private final DataFrameAnalyticsConfig analytics;
    private final DataFrameAnalyticsAuditor auditor;
    private final Consumer<Exception> failureHandler;
    private final ExtractedFields extractedFields;
    private final AtomicBoolean readyToStoreNewModel = new AtomicBoolean(true);

    public ChunkedTrainedModelPersister(TrainedModelProvider provider,
                                        DataFrameAnalyticsConfig analytics,
                                        DataFrameAnalyticsAuditor auditor,
                                        Consumer<Exception> failureHandler,
                                        ExtractedFields extractedFields) {
        this.provider = provider;
        this.currentModelId = new AtomicReference<>("");
        this.analytics = analytics;
        this.auditor = auditor;
        this.failureHandler = failureHandler;
        this.extractedFields = extractedFields;
    }

    public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
        if (Strings.isNullOrEmpty(this.currentModelId.get())) {
            failureHandler.accept(ExceptionsHelper.serverError(
                "chunked inference model definition is attempting to be stored before trained model configuration"
            ));
            return;
        }
        TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc(this.currentModelId.get());

        CountDownLatch latch = storeTrainedModelDoc(trainedModelDefinitionDoc);
        try {
            if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
                LOGGER.error("[{}] Timed out (30s) waiting for chunked inference definition to be stored", analytics.getId());
                if (trainedModelDefinitionChunk.isEos()) {
                    this.readyToStoreNewModel.set(true);
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            this.readyToStoreNewModel.set(true);
            failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for chunked inference definition to be stored"));
        }
    }

    public void createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) {
        if (readyToStoreNewModel.compareAndSet(true, false) == false) {
            failureHandler.accept(ExceptionsHelper.serverError(
                "new inference model is attempting to be stored before completion previous model storage"
            ));
            return;
        }
        TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModelSize);
        CountDownLatch latch = storeTrainedModelMetadata(trainedModelConfig);
        try {
            if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
                LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            this.readyToStoreNewModel.set(true);
            failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored"));
        }
    }

    private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc) {
        CountDownLatch latch = new CountDownLatch(1);

        // Latch is attached to this action as it is the last one to execute.
        ActionListener<RefreshResponse> refreshListener = new LatchedActionListener<>(ActionListener.wrap(
            refreshed -> {
                if (refreshed != null) {
                    LOGGER.debug(() -> new ParameterizedMessage(
                        "[{}] refreshed inference index after model store",
                        analytics.getId()
                    ));
                }
            },
            e -> LOGGER.warn(
                new ParameterizedMessage("[{}] failed to refresh inference index after model store", analytics.getId()),
                e)
        ), latch);

        // First, store the model and refresh is necessary
        ActionListener<Void> storeListener = ActionListener.wrap(
            r -> {
                LOGGER.debug(() -> new ParameterizedMessage(
                    "[{}] stored trained model definition chunk [{}] [{}]",
                    analytics.getId(),
                    trainedModelDefinitionDoc.getModelId(),
                    trainedModelDefinitionDoc.getDocNum()));
                if (trainedModelDefinitionDoc.isEos() == false) {
                    refreshListener.onResponse(null);
                    return;
                }
                LOGGER.info(
                    "[{}] finished storing trained model with id [{}]",
                    analytics.getId(),
                    this.currentModelId.get());
                auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]");
                this.currentModelId.set("");
                readyToStoreNewModel.set(true);
                provider.refreshInferenceIndex(refreshListener);
            },
            e -> {
                this.readyToStoreNewModel.set(true);
                failureHandler.accept(ExceptionsHelper.serverError(
                    "error storing trained model definition chunk [{}] with id [{}]",
                    e,
                    trainedModelDefinitionDoc.getModelId(),
                    trainedModelDefinitionDoc.getDocNum()));
                refreshListener.onResponse(null);
            }
        );
        provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, storeListener);
        return latch;
    }

    private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig) {
        CountDownLatch latch = new CountDownLatch(1);
        ActionListener<Boolean> storeListener = ActionListener.wrap(
            aBoolean -> {
                if (aBoolean == false) {
                    LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId());
                    readyToStoreNewModel.set(true);
                    failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false"));
                } else {
                    LOGGER.debug("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
                }
            },
            e -> {
                readyToStoreNewModel.set(true);
                failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]",
                    e,
                    trainedModelConfig.getModelId()));
            }
        );
        provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
        return latch;
    }

    private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) {
        Instant createTime = Instant.now();
        String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
        currentModelId.set(modelId);
        List<ExtractedField> fieldNames = extractedFields.getAllFields();
        String dependentVariable = getDependentVariable();
        List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
            .map(ExtractedField::getName)
            .filter(f -> f.equals(dependentVariable) == false)
            .collect(toList());
        Map<String, String> defaultFieldMapping = fieldNames.stream()
            .filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
            .collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
        return TrainedModelConfig.builder()
            .setModelId(modelId)
            .setCreatedBy(XPackUser.NAME)
            .setVersion(Version.CURRENT)
            .setCreateTime(createTime)
            // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
            .setTags(Collections.singletonList(analytics.getId()))
            .setDescription(analytics.getDescription())
            .setMetadata(Collections.singletonMap("analytics_config",
                XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
            .setEstimatedHeapMemory(modelSize.ramBytesUsed())
            .setEstimatedOperations(modelSize.numOperations())
            .setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .setDefaultFieldMap(defaultFieldMapping)
            .setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields)))
            .build();
    }

    private String getDependentVariable() {
        if (analytics.getAnalysis() instanceof Classification) {
            return ((Classification)analytics.getAnalysis()).getDependentVariable();
        }
        if (analytics.getAnalysis() instanceof Regression) {
            return ((Regression)analytics.getAnalysis()).getDependentVariable();
        }
        return null;
    }

}
@@ -8,31 +8,27 @@ package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;

import java.io.IOException;
import java.util.Collections;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;

public class AnalyticsResult implements ToXContentObject {

    public static final ParseField TYPE = new ParseField("analytics_result");

    private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
    private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
    private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info");
    private static final ParseField COMPRESSED_INFERENCE_MODEL = new ParseField("compressed_inference_model");
    private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
    private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
    private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");

@@ -42,53 +38,50 @@ public class AnalyticsResult implements ToXContentObject {
        a -> new AnalyticsResult(
            (RowResults) a[0],
            (PhaseProgress) a[1],
            (TrainedModelDefinition.Builder) a[2],
            (MemoryUsage) a[3],
            (OutlierDetectionStats) a[4],
            (ClassificationStats) a[5],
            (RegressionStats) a[6],
            (ModelSizeInfo) a[7]
            (MemoryUsage) a[2],
            (OutlierDetectionStats) a[3],
            (ClassificationStats) a[4],
            (RegressionStats) a[5],
            (ModelSizeInfo) a[6],
            (TrainedModelDefinitionChunk) a[7]
        ));

    static {
        PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
        PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS);
        // TODO change back to STRICT_PARSER once native side is aligned
        PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
        PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE);
        PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
        PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
        PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
        PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO);
        PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinitionChunk.PARSER, COMPRESSED_INFERENCE_MODEL);
    }

    private final RowResults rowResults;
    private final PhaseProgress phaseProgress;
    private final TrainedModelDefinition.Builder inferenceModelBuilder;
    private final TrainedModelDefinition inferenceModel;
    private final MemoryUsage memoryUsage;
    private final OutlierDetectionStats outlierDetectionStats;
    private final ClassificationStats classificationStats;
    private final RegressionStats regressionStats;
    private final ModelSizeInfo modelSizeInfo;
    private final TrainedModelDefinitionChunk trainedModelDefinitionChunk;

    public AnalyticsResult(@Nullable RowResults rowResults,
                           @Nullable PhaseProgress phaseProgress,
                           @Nullable TrainedModelDefinition.Builder inferenceModelBuilder,
                           @Nullable MemoryUsage memoryUsage,
                           @Nullable OutlierDetectionStats outlierDetectionStats,
                           @Nullable ClassificationStats classificationStats,
                           @Nullable RegressionStats regressionStats,
                           @Nullable ModelSizeInfo modelSizeInfo) {
                           @Nullable ModelSizeInfo modelSizeInfo,
                           @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
        this.rowResults = rowResults;
        this.phaseProgress = phaseProgress;
        this.inferenceModelBuilder = inferenceModelBuilder;
        this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
        this.memoryUsage = memoryUsage;
        this.outlierDetectionStats = outlierDetectionStats;
        this.classificationStats = classificationStats;
        this.regressionStats = regressionStats;
        this.modelSizeInfo = modelSizeInfo;
        this.trainedModelDefinitionChunk = trainedModelDefinitionChunk;
    }

    public RowResults getRowResults() {

@@ -99,10 +92,6 @@ public class AnalyticsResult implements ToXContentObject {
        return phaseProgress;
    }

    public TrainedModelDefinition.Builder getInferenceModelBuilder() {
        return inferenceModelBuilder;
    }

    public MemoryUsage getMemoryUsage() {
        return memoryUsage;
    }

@@ -123,6 +112,10 @@ public class AnalyticsResult implements ToXContentObject {
        return modelSizeInfo;
    }

    public TrainedModelDefinitionChunk getTrainedModelDefinitionChunk() {
        return trainedModelDefinitionChunk;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

@@ -132,11 +125,6 @@ public class AnalyticsResult implements ToXContentObject {
        if (phaseProgress != null) {
            builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress);
        }
        if (inferenceModel != null) {
            builder.field(INFERENCE_MODEL.getPreferredName(),
                inferenceModel,
                new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")));
        }
        if (memoryUsage != null) {
            builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params);
        }

@@ -152,6 +140,9 @@ public class AnalyticsResult implements ToXContentObject {
        if (modelSizeInfo != null) {
            builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo);
        }
        if (trainedModelDefinitionChunk != null) {
            builder.field(COMPRESSED_INFERENCE_MODEL.getPreferredName(), trainedModelDefinitionChunk);
        }
        builder.endObject();
        return builder;
    }

@@ -168,17 +159,17 @@ public class AnalyticsResult implements ToXContentObject {
        AnalyticsResult that = (AnalyticsResult) other;
        return Objects.equals(rowResults, that.rowResults)
            && Objects.equals(phaseProgress, that.phaseProgress)
            && Objects.equals(inferenceModel, that.inferenceModel)
            && Objects.equals(memoryUsage, that.memoryUsage)
            && Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
            && Objects.equals(classificationStats, that.classificationStats)
            && Objects.equals(modelSizeInfo, that.modelSizeInfo)
            && Objects.equals(trainedModelDefinitionChunk, that.trainedModelDefinitionChunk)
            && Objects.equals(regressionStats, that.regressionStats);
    }

    @Override
    public int hashCode() {
        return Objects.hash(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
            regressionStats);
        return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats,
            regressionStats, modelSizeInfo, trainedModelDefinitionChunk);
    }
}
@@ -0,0 +1,89 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.dataframe.process.results;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;

import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class TrainedModelDefinitionChunk implements ToXContentObject {

    private static final ParseField DEFINITION = new ParseField("definition");
    private static final ParseField DOC_NUM = new ParseField("doc_num");
    private static final ParseField EOS = new ParseField("eos");

    public static final ConstructingObjectParser<TrainedModelDefinitionChunk, Void> PARSER = new ConstructingObjectParser<>(
        "chunked_trained_model_definition",
        a -> new TrainedModelDefinitionChunk((String) a[0], (Integer) a[1], (Boolean) a[2]));

    static {
        PARSER.declareString(constructorArg(), DEFINITION);
        PARSER.declareInt(constructorArg(), DOC_NUM);
        PARSER.declareBoolean(optionalConstructorArg(), EOS);
    }

    private final String definition;
    private final int docNum;
    private final Boolean eos;

    public TrainedModelDefinitionChunk(String definition, int docNum, Boolean eos) {
        this.definition = definition;
        this.docNum = docNum;
        this.eos = eos;
    }

    public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId) {
        return new TrainedModelDefinitionDoc.Builder()
            .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
            .setModelId(modelId)
            .setDefinitionLength(definition.length())
            .setDocNum(docNum)
            .setCompressedString(definition)
            .setEos(isEos())
            .build();
    }

    public boolean isEos() {
        return eos != null && eos;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(DEFINITION.getPreferredName(), definition);
        builder.field(DOC_NUM.getPreferredName(), docNum);
        if (eos != null) {
            builder.field(EOS.getPreferredName(), eos);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        TrainedModelDefinitionChunk that = (TrainedModelDefinitionChunk) o;
        return docNum == that.docNum
            && Objects.equals(definition, that.definition)
            && Objects.equals(eos, that.eos);
    }

    @Override
    public int hashCode() {
        return Objects.hash(definition, docNum, eos);
    }
}
@@ -33,6 +33,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
    public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version");
    public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length");
    public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length");
    public static final ParseField EOS = new ParseField("eos");

    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
    public static final ObjectParser<TrainedModelDefinitionDoc.Builder, Void> LENIENT_PARSER = createParser(true);

@@ -48,6 +49,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
        parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION);
        parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH);
        parser.declareLong(TrainedModelDefinitionDoc.Builder::setTotalDefinitionLength, TOTAL_DEFINITION_LENGTH);
        parser.declareBoolean(TrainedModelDefinitionDoc.Builder::setEos, EOS);
        return parser;
    }

@@ -63,23 +65,26 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
    private final String compressedString;
    private final String modelId;
    private final int docNum;
    private final long totalDefinitionLength;
    // for BWC
    private final Long totalDefinitionLength;
    private final long definitionLength;
    private final int compressionVersion;
    private final boolean eos;

    private TrainedModelDefinitionDoc(String compressedString,
                                      String modelId,
                                      int docNum,
                                      long totalDefinitionLength,
                                      Long totalDefinitionLength,
                                      long definitionLength,
                                      int compressionVersion) {
                                      int compressionVersion,
                                      boolean eos) {
        this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION);
        this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
        if (docNum < 0) {
            throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0");
        }
        this.docNum = docNum;
        if (totalDefinitionLength <= 0L) {
        if (totalDefinitionLength != null && totalDefinitionLength <= 0L) {
            throw new IllegalArgumentException("[total_definition_length] must be greater than 0");
        }
        this.totalDefinitionLength = totalDefinitionLength;

@@ -88,6 +93,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
        }
        this.definitionLength = definitionLength;
        this.compressionVersion = compressionVersion;
        this.eos = eos;
    }

    public String getCompressedString() {

@@ -102,7 +108,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
        return docNum;
    }

    public long getTotalDefinitionLength() {
    public Long getTotalDefinitionLength() {
        return totalDefinitionLength;
    }

@@ -114,16 +120,24 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
        return compressionVersion;
    }

    public boolean isEos() {
        return eos;
    }

    public String getDocId() {
        return docId(modelId, docNum);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
        builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
        builder.field(DOC_NUM.getPreferredName(), docNum);
        builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength);
        builder.field(DEFINITION_LENGTH.getPreferredName(), definitionLength);
        builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion);
        builder.field(DEFINITION.getPreferredName(), compressedString);
        builder.field(EOS.getPreferredName(), eos);
        builder.endObject();
        return builder;
    }

@@ -143,12 +157,13 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
            Objects.equals(definitionLength, that.definitionLength) &&
            Objects.equals(totalDefinitionLength, that.totalDefinitionLength) &&
            Objects.equals(compressionVersion, that.compressionVersion) &&
            Objects.equals(eos, that.eos) &&
            Objects.equals(compressedString, that.compressedString);
    }

    @Override
    public int hashCode() {
        return Objects.hash(modelId, docNum, totalDefinitionLength, definitionLength, compressionVersion, compressedString);
        return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos);
    }

    public static class Builder {

@@ -156,9 +171,10 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
        private String modelId;
        private String compressedString;
        private int docNum;
        private long totalDefinitionLength;
        private Long totalDefinitionLength;
        private long definitionLength;
        private int compressionVersion;
        private boolean eos;

        public Builder setModelId(String modelId) {
            this.modelId = modelId;

@@ -190,6 +206,11 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
            return this;
        }

        public Builder setEos(boolean eos) {
            this.eos = eos;
            return this;
        }

        public TrainedModelDefinitionDoc build() {
            return new TrainedModelDefinitionDoc(
                this.compressedString,

@@ -197,7 +218,8 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
                this.docNum,
                this.totalDefinitionLength,
                this.definitionLength,
                this.compressionVersion);
                this.compressionVersion,
                this.eos);
        }
    }
@ -14,10 +14,14 @@ import org.elasticsearch.ResourceAlreadyExistsException;
|
|||
import org.elasticsearch.ResourceNotFoundException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.DocWriteRequest;
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
|
||||
import org.elasticsearch.action.bulk.BulkAction;
|
||||
import org.elasticsearch.action.bulk.BulkItemResponse;
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexAction;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.MultiSearchAction;
|
||||
import org.elasticsearch.action.search.MultiSearchRequest;
|
||||
|
@ -143,6 +147,74 @@ public class TrainedModelProvider {
        storeTrainedModelAndDefinition(trainedModelConfig, listener);
    }

    public void storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig,
                                          ActionListener<Boolean> listener) {
        if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) {
            listener.onFailure(new ResourceAlreadyExistsException(
                Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
            return;
        }
        assert trainedModelConfig.getModelDefinition() == null;

        executeAsyncWithOrigin(client,
            ML_ORIGIN,
            IndexAction.INSTANCE,
            createRequest(trainedModelConfig.getModelId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelConfig),
            ActionListener.wrap(
                indexResponse -> listener.onResponse(true),
                e -> {
                    if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
                        listener.onFailure(new ResourceAlreadyExistsException(
                            Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
                    } else {
                        listener.onFailure(
                            new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
                                RestStatus.INTERNAL_SERVER_ERROR,
                                e,
                                trainedModelConfig.getModelId()));
                    }
                }
            ));
    }

    public void storeTrainedModelDefinitionDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc, ActionListener<Void> listener) {
        if (MODELS_STORED_AS_RESOURCE.contains(trainedModelDefinitionDoc.getModelId())) {
            listener.onFailure(new ResourceAlreadyExistsException(
                Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelDefinitionDoc.getModelId())));
            return;
        }

        executeAsyncWithOrigin(client,
            ML_ORIGIN,
            IndexAction.INSTANCE,
            createRequest(trainedModelDefinitionDoc.getDocId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelDefinitionDoc),
            ActionListener.wrap(
                indexResponse -> listener.onResponse(null),
                e -> {
                    if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
                        listener.onFailure(new ResourceAlreadyExistsException(
                            Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_DOC_EXISTS,
                                trainedModelDefinitionDoc.getModelId(),
                                trainedModelDefinitionDoc.getDocNum())));
                    } else {
                        listener.onFailure(
                            new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
                                RestStatus.INTERNAL_SERVER_ERROR,
                                e,
                                trainedModelDefinitionDoc.getModelId()));
                    }
                }
            ));
    }

    public void refreshInferenceIndex(ActionListener<RefreshResponse> listener) {
        executeAsyncWithOrigin(client,
            ML_ORIGIN,
            RefreshAction.INSTANCE,
            new RefreshRequest(InferenceIndexConstants.INDEX_PATTERN),
            listener);
    }
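Per the commit description, these three new methods are meant to be driven in a fixed order: the `model_size_info` result produces the config stored by `storeTrainedModelMetadata`, the definition chunks then arrive in order for `storeTrainedModelDefinitionDoc`, and a final `refreshInferenceIndex` makes everything searchable. A sketch of that happy path (listener wiring simplified; `chunkDocs` is a hypothetical pre-built, docNum-ordered list):

    // Sketch only: the assumed sequence for persisting a streamed model.
    provider.storeTrainedModelMetadata(config, ActionListener.wrap(
        stored -> {
            for (TrainedModelDefinitionDoc doc : chunkDocs) { // chunks, in docNum order
                provider.storeTrainedModelDefinitionDoc(doc, ActionListener.wrap(ok -> {}, listener::onFailure));
            }
            provider.refreshInferenceIndex(ActionListener.wrap( // then make the docs searchable
                refreshed -> listener.onResponse(true), listener::onFailure));
        },
        listener::onFailure));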

    private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig,
                                                ActionListener<Boolean> listener) {
@ -165,7 +237,8 @@ public class TrainedModelProvider {
                    .setCompressedString(chunkedStrings.get(i))
                    .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
                    .setDefinitionLength(chunkedStrings.get(i).length())
                    .setTotalDefinitionLength(compressedString.length())
                    // If it is the last doc, it is the EOS
                    .setEos(i == chunkedStrings.size() - 1)
                    .build());
            }
        } catch (IOException ex) {
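To make the chunk bookkeeping concrete: a 250-character compressed string with a 100-character chunk size yields three docs — docNum 0 and 1 each with `definitionLength` 100 and `eos` false, and docNum 2 with `definitionLength` 50 and `eos` true. In test form (assuming `chunkStringWithSize` behaves as its name and the loop further down suggest):

    List<String> chunks = chunkStringWithSize(compressed, 100); // compressed.length() == 250
    assert chunks.size() == 3;
    assert chunks.get(2).length() == 50; // the short tail chunk becomes the EOS doc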
@ -265,6 +338,9 @@ public class TrainedModelProvider {
            .unmappedType("long"))
            .request();
        executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
            // TODO how could we stream in the model definition WHILE parsing it?
            // This would reduce the overall memory usage as we won't have to load the whole compressed string
            // XContentParser supports streams.
            searchResponse -> {
                if (searchResponse.getHits().getHits().length == 0) {
                    listener.onFailure(new ResourceNotFoundException(
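On the TODO above: if the stored definition is a base64-encoded compressed payload (an assumption here — the exact encoding is internal to `InferenceToXContentCompressor`), the chunk strings could in principle be chained into a single `InputStream` and inflated lazily, so the joined string never has to be materialized on heap. A sketch with JDK streams only:

    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.SequenceInputStream;
    import java.nio.charset.StandardCharsets;
    import java.util.Base64;
    import java.util.Collections;
    import java.util.List;
    import java.util.stream.Collectors;
    import java.util.zip.GZIPInputStream;

    final class StreamingInflateSketch {
        // Chain per-doc chunk strings and decode/inflate them as one stream.
        static InputStream inflate(List<String> chunks) throws IOException {
            InputStream joined = new SequenceInputStream(Collections.enumeration(
                chunks.stream()
                    .map(s -> (InputStream) new ByteArrayInputStream(s.getBytes(StandardCharsets.UTF_8)))
                    .collect(Collectors.toList())));
            // Assumes base64-wrapped gzip; swap decoder/inflater to match the real format.
            return new GZIPInputStream(Base64.getDecoder().wrap(joined));
        }
    }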
@ -274,19 +350,16 @@ public class TrainedModelProvider {
                List<TrainedModelDefinitionDoc> docs = handleHits(searchResponse.getHits().getHits(),
                    modelId,
                    this::parseModelDefinitionDocLenientlyFromSource);
                String compressedString = docs.stream()
                    .map(TrainedModelDefinitionDoc::getCompressedString)
                    .collect(Collectors.joining());
                if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
                    listener.onFailure(ExceptionsHelper.serverError(
                        Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
                    return;
                try {
                    String compressedString = getDefinitionFromDocs(docs, modelId);
                    InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
                        compressedString,
                        InferenceDefinition::fromXContent,
                        xContentRegistry);
                    listener.onResponse(inferenceDefinition);
                } catch (ElasticsearchException elasticsearchException) {
                    listener.onFailure(elasticsearchException);
                }
                InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
                    compressedString,
                    InferenceDefinition::fromXContent,
                    xContentRegistry);
                listener.onResponse(inferenceDefinition);
            },
            e -> {
                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
@ -361,15 +434,14 @@ public class TrainedModelProvider {
                    List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
                        modelId,
                        this::parseModelDefinitionDocLenientlyFromSource);
                    String compressedString = docs.stream()
                        .map(TrainedModelDefinitionDoc::getCompressedString)
                        .collect(Collectors.joining());
                    if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
                        listener.onFailure(ExceptionsHelper.serverError(
                            Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
                    try {
                        String compressedString = getDefinitionFromDocs(docs, modelId);
                        builder.setDefinitionFromString(compressedString);
                    } catch (ElasticsearchException elasticsearchException) {
                        listener.onFailure(elasticsearchException);
                        return;
                    }
                    builder.setDefinitionFromString(compressedString);

                } catch (ResourceNotFoundException ex) {
                    listener.onFailure(new ResourceNotFoundException(
                        Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
@ -806,6 +878,26 @@ public class TrainedModelProvider {
        return results;
    }

    private static String getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId) throws ElasticsearchException {
        String compressedString = docs.stream()
            .map(TrainedModelDefinitionDoc::getCompressedString)
            .collect(Collectors.joining());
        // BWC for when we tracked the total definition length
        // TODO: remove in 9
        if (docs.get(0).getTotalDefinitionLength() != null) {
            if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
                throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId));
            }
        } else {
            TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1);
            // Either we are missing the last doc, or some previous doc
            if (lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) {
                throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId));
            }
        }
        return compressedString;
    }
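The truncation check now has two paths: docs written before this change carry `totalDefinitionLength`, so the joined length is compared against it; new docs leave it null and rely on the `eos` flag plus contiguous `docNum` values instead. Note that the `getDocNum() != docs.size() - 1` arithmetic assumes the hits come back sorted by `docNum` with no duplicates — which the `docNum` sort in the search above and create-only indexing are there to guarantee. The new-format invariant, spelled out:

    // Sketch: completeness check for eos-style docs.
    TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1);
    boolean complete = lastDoc.isEos()                  // the EOS chunk arrived
        && lastDoc.getDocNum() == docs.size() - 1;      // and no earlier chunk is missing
    // complete == false -> MODEL_DEFINITION_TRUNCATED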

    static List<String> chunkStringWithSize(String str, int chunkSize) {
        List<String> subStrings = new ArrayList<>((int) Math.ceil(str.length() / (double) chunkSize));
        for (int i = 0; i < str.length(); i += chunkSize) {
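The hunk cuts off mid-loop; a minimal sketch of the loop body such a chunking helper needs (the real body may differ):

    // Copy [i, min(i + chunkSize, str.length())) on each iteration.
    subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));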
@ -836,14 +928,18 @@ public class TrainedModelProvider {
        }
    }

    private IndexRequest createRequest(String docId, String index, ToXContentObject body) {
        return createRequest(new IndexRequest(index), docId, body);
    }

    private IndexRequest createRequest(String docId, ToXContentObject body) {
        return createRequest(new IndexRequest(), docId, body);
    }

    private IndexRequest createRequest(IndexRequest request, String docId, ToXContentObject body) {
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);

            return new IndexRequest()
                .opType(DocWriteRequest.OpType.CREATE)
                .id(docId)
                .source(source);
            return request.opType(DocWriteRequest.OpType.CREATE).id(docId).source(source);
        } catch (IOException ex) {
            // This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it
            // again, that is not the user's fault. We did something wrong and should throw.
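Both store methods funnel through these helpers, and `OpType.CREATE` is what turns a duplicate doc id into the `VersionConflictEngineException` that `storeTrainedModelMetadata` and `storeTrainedModelDefinitionDoc` translate into "already exists" failures — so re-delivery of a chunk cannot silently overwrite a stored one:

    // Sketch: with OpType.CREATE, indexing the same id twice fails rather than overwrites.
    IndexRequest first  = new IndexRequest("idx").opType(DocWriteRequest.OpType.CREATE).id("doc-0").source(source);
    IndexRequest second = new IndexRequest("idx").opType(DocWriteRequest.OpType.CREATE).id("doc-0").source(source);
    // first succeeds; second is rejected with a VersionConflictEngineException.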
@ -5,32 +5,20 @@
 */
package org.elasticsearch.xpack.ml.dataframe.process;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.MultiField;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
@ -38,20 +26,14 @@ import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -156,90 +138,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
        assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
    }

    @SuppressWarnings("unchecked")
    public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
        givenDataFrameRows(0);

        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
            storeListener.onResponse(true);
            return null;
        }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));

        List<ExtractedField> extractedFieldList = new ArrayList<>(3);
        extractedFieldList.add(new DocValueField("foo", Collections.emptySet()));
        extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet())));
        extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
        TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
        TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
        AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);

        resultProcessor.process(process);
        resultProcessor.awaitForCompletion();

        ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
        verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));

        TrainedModelConfig storedModel = storedModelCaptor.getValue();
        assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
        assertThat(storedModel.getModelId(), containsString(JOB_ID));
        assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
        assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME));
        assertThat(storedModel.getTags(), contains(JOB_ID));
        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
        assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
        assertThat(storedModel.getDefaultFieldMap(), equalTo(Collections.singletonMap("bar", "bar.keyword")));
        assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz")));
        assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
        assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
        if (targetType.equals(TargetType.CLASSIFICATION)) {
            assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification"));
        } else {
            assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression"));
        }
        Map<String, Object> metadata = storedModel.getMetadata();
        assertThat(metadata.size(), equalTo(1));
        assertThat(metadata, hasKey("analytics_config"));
        Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
            true);
        assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));

        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
        verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
        assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
        Mockito.verifyNoMoreInteractions(auditor);
    }

    @SuppressWarnings("unchecked")
    public void testProcess_GivenInferenceModelFailedToStore() {
        givenDataFrameRows(0);

        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
            storeListener.onFailure(new RuntimeException("some failure"));
            return null;
        }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));

        TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
        TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
        AnalyticsResultProcessor resultProcessor = createResultProcessor();

        resultProcessor.process(process);
        resultProcessor.awaitForCompletion();

        // This test verifies the processor knows how to handle a failure on storing the model and completes normally
        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
        verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
        assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID));
        Mockito.verifyNoMoreInteractions(auditor);

        assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
        assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
    }

    private void givenProcessResults(List<AnalyticsResult> results) {
        when(process.readAnalyticsResults()).thenReturn(results.iterator());
    }
@ -256,7 +154,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
    }

    private AnalyticsResultProcessor createResultProcessor(List<ExtractedField> fieldNames) {

        return new AnalyticsResultProcessor(analyticsConfig,
            dataFrameRowsJoiner,
            statsHolder,
@ -0,0 +1,150 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */

package org.elasticsearch.xpack.ml.dataframe.process;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class ChunkedTrainedModelPersisterTests extends ESTestCase {

    private static final String JOB_ID = "analytics-result-processor-tests";
    private static final String JOB_DESCRIPTION = "This describes the job of these tests";

    private TrainedModelProvider trainedModelProvider;
    private DataFrameAnalyticsAuditor auditor;

    @Before
    public void setUpMocks() {
        trainedModelProvider = mock(TrainedModelProvider.class);
        auditor = mock(DataFrameAnalyticsAuditor.class);
    }

    @SuppressWarnings("unchecked")
    public void testPersistAllDocs() {
        DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
            .setId(JOB_ID)
            .setDescription(JOB_DESCRIPTION)
            .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
            .setDest(new DataFrameAnalyticsDest("my_dest", null))
            .setAnalysis(randomBoolean() ? new Regression("foo") : new Classification("foo"))
            .build();
        List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));

        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
            storeListener.onResponse(true);
            return null;
        }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class));

        doAnswer(invocationOnMock -> {
            ActionListener<Void> storeListener = (ActionListener<Void>) invocationOnMock.getArguments()[1];
            storeListener.onResponse(null);
            return null;
        }).when(trainedModelProvider).storeTrainedModelDefinitionDoc(any(TrainedModelDefinitionDoc.class), any(ActionListener.class));

        ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig);
        ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
        TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 0, false);
        TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true);

        resultProcessor.createAndIndexInferenceModelMetadata(modelSizeInfo);
        resultProcessor.createAndIndexInferenceModelDoc(chunk1);
        resultProcessor.createAndIndexInferenceModelDoc(chunk2);

        ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
        verify(trainedModelProvider).storeTrainedModelMetadata(storedModelCaptor.capture(), any(ActionListener.class));

        ArgumentCaptor<TrainedModelDefinitionDoc> storedDocCapture = ArgumentCaptor.forClass(TrainedModelDefinitionDoc.class);
        verify(trainedModelProvider, times(2))
            .storeTrainedModelDefinitionDoc(storedDocCapture.capture(), any(ActionListener.class));

        TrainedModelConfig storedModel = storedModelCaptor.getValue();
        assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
        assertThat(storedModel.getModelId(), containsString(JOB_ID));
        assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
        assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME));
        assertThat(storedModel.getTags(), contains(JOB_ID));
        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
        assertThat(storedModel.getModelDefinition(), is(nullValue()));
        assertThat(storedModel.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
        assertThat(storedModel.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations()));
        if (analyticsConfig.getAnalysis() instanceof Classification) {
            assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification"));
        } else {
            assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression"));
        }
        Map<String, Object> metadata = storedModel.getMetadata();
        assertThat(metadata.size(), equalTo(1));
        assertThat(metadata, hasKey("analytics_config"));
        Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
            true);
        assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));

        TrainedModelDefinitionDoc storedDoc1 = storedDocCapture.getAllValues().get(0);
        assertThat(storedDoc1.getDocNum(), equalTo(0));
        TrainedModelDefinitionDoc storedDoc2 = storedDocCapture.getAllValues().get(1);
        assertThat(storedDoc2.getDocNum(), equalTo(1));

        assertThat(storedModel.getModelId(), equalTo(storedDoc1.getModelId()));
        assertThat(storedModel.getModelId(), equalTo(storedDoc2.getModelId()));

        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
        verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
        assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
        Mockito.verifyNoMoreInteractions(auditor);
    }

    private ChunkedTrainedModelPersister createChunkedTrainedModelPersister(List<ExtractedField> fieldNames,
                                                                            DataFrameAnalyticsConfig analyticsConfig) {
        return new ChunkedTrainedModelPersister(trainedModelProvider,
            analyticsConfig,
            auditor,
            (unused) -> {},
            new ExtractedFields(fieldNames, Collections.emptyMap()));
    }
}
@ -20,8 +20,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
@ -46,21 +44,18 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
    protected AnalyticsResult createTestInstance() {
        RowResults rowResults = null;
        PhaseProgress phaseProgress = null;
        TrainedModelDefinition.Builder inferenceModel = null;
        MemoryUsage memoryUsage = null;
        OutlierDetectionStats outlierDetectionStats = null;
        ClassificationStats classificationStats = null;
        RegressionStats regressionStats = null;
        ModelSizeInfo modelSizeInfo = null;
        TrainedModelDefinitionChunk trainedModelDefinitionChunk = null;
        if (randomBoolean()) {
            rowResults = RowResultsTests.createRandom();
        }
        if (randomBoolean()) {
            phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100));
        }
        if (randomBoolean()) {
            inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
        }
        if (randomBoolean()) {
            memoryUsage = MemoryUsageTests.createRandom();
        }
@ -76,8 +71,12 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
        if (randomBoolean()) {
            modelSizeInfo = ModelSizeInfoTests.createRandom();
        }
        return new AnalyticsResult(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats,
            classificationStats, regressionStats, modelSizeInfo);
        if (randomBoolean()) {
            String def = randomAlphaOfLengthBetween(100, 1000);
            trainedModelDefinitionChunk = new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean());
        }
        return new AnalyticsResult(rowResults, phaseProgress, memoryUsage, outlierDetectionStats,
            classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk);
    }

    @Override