* [ML] handles compressed model stream from native process (#58009) This moves model storage from handling the fully parsed JSON string to handling two separate types of documents. 1. ModelSizeInfo which contains model size information 2. TrainedModelDefinitionChunk which contains a particular chunk of the compressed model definition string. `model_size_info` is assumed to be handled first. This will generate the model_id and store the initial trained model config object. Then each chunk is assumed to be in correct order for concatenating the chunks to get a compressed definition. Native side change: https://github.com/elastic/ml-cpp/pull/1349
This commit is contained in:
parent
9c77862a23
commit
c64e283dbf
|
@ -89,6 +89,7 @@ public final class Messages {
|
||||||
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";
|
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";
|
||||||
|
|
||||||
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
|
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
|
||||||
|
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
|
||||||
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
|
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
|
||||||
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
|
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
|
||||||
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
|
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
package org.elasticsearch.xpack.ml.integration;
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
|
||||||
import org.elasticsearch.ElasticsearchException;
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.ActionModule;
|
import org.elasticsearch.action.ActionModule;
|
||||||
|
@ -67,7 +66,6 @@ import static org.hamcrest.Matchers.not;
|
||||||
import static org.hamcrest.Matchers.nullValue;
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
import static org.hamcrest.Matchers.startsWith;
|
import static org.hamcrest.Matchers.startsWith;
|
||||||
|
|
||||||
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
|
|
||||||
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
private static final String BOOLEAN_FIELD = "boolean-field";
|
private static final String BOOLEAN_FIELD = "boolean-field";
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.ml.integration;
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
|
||||||
import org.elasticsearch.ElasticsearchException;
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.action.ActionModule;
|
import org.elasticsearch.action.ActionModule;
|
||||||
import org.elasticsearch.action.DocWriteRequest;
|
import org.elasticsearch.action.DocWriteRequest;
|
||||||
|
@ -45,7 +44,6 @@ import static org.hamcrest.Matchers.is;
|
||||||
import static org.hamcrest.Matchers.lessThan;
|
import static org.hamcrest.Matchers.lessThan;
|
||||||
import static org.hamcrest.Matchers.not;
|
import static org.hamcrest.Matchers.not;
|
||||||
|
|
||||||
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
|
|
||||||
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchException;
|
||||||
|
import org.elasticsearch.Version;
|
||||||
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.license.License;
|
||||||
|
import org.elasticsearch.xpack.core.action.util.PageParams;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
||||||
|
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
|
||||||
|
import org.elasticsearch.xpack.ml.extractor.DocValueField;
|
||||||
|
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||||
|
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
|
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
|
public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
|
||||||
|
|
||||||
|
private TrainedModelProvider trainedModelProvider;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void createComponents() throws Exception {
|
||||||
|
trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
|
||||||
|
waitForMlTemplates();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testStoreModelViaChunkedPersister() throws IOException {
|
||||||
|
String modelId = "stored-chunked-model";
|
||||||
|
DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
|
||||||
|
.setId(modelId)
|
||||||
|
.setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
|
||||||
|
.setDest(new DataFrameAnalyticsDest("my_dest", null))
|
||||||
|
.setAnalysis(new Regression("foo"))
|
||||||
|
.build();
|
||||||
|
List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));
|
||||||
|
TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
|
||||||
|
String compressedDefinition = configBuilder.build().getCompressedDefinition();
|
||||||
|
int totalSize = compressedDefinition.length();
|
||||||
|
List<String> chunks = chunkStringWithSize(compressedDefinition, totalSize/3);
|
||||||
|
|
||||||
|
ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider,
|
||||||
|
analyticsConfig,
|
||||||
|
new DataFrameAnalyticsAuditor(client(), "test-node"),
|
||||||
|
(ex) -> { throw new ElasticsearchException(ex); },
|
||||||
|
new ExtractedFields(extractedFieldList, Collections.emptyMap())
|
||||||
|
);
|
||||||
|
|
||||||
|
//Accuracy for size is not tested here
|
||||||
|
ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
|
||||||
|
persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
|
||||||
|
for (int i = 0; i < chunks.size(); i++) {
|
||||||
|
persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1)));
|
||||||
|
}
|
||||||
|
|
||||||
|
PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
|
||||||
|
trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
|
||||||
|
Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
|
||||||
|
assertThat(ids.v1(), equalTo(1L));
|
||||||
|
|
||||||
|
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
|
||||||
|
trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);
|
||||||
|
|
||||||
|
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
|
||||||
|
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
|
||||||
|
assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
|
||||||
|
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
|
||||||
|
TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder();
|
||||||
|
long bytesUsed = definitionBuilder.build().ramBytesUsed();
|
||||||
|
long operations = definitionBuilder.build().getTrainedModel().estimatedNumOperations();
|
||||||
|
return TrainedModelConfig.builder()
|
||||||
|
.setCreatedBy("ml_test")
|
||||||
|
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION))
|
||||||
|
.setDescription("trained model config for test")
|
||||||
|
.setModelId(modelId)
|
||||||
|
.setVersion(Version.CURRENT)
|
||||||
|
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||||
|
.setEstimatedHeapMemory(bytesUsed)
|
||||||
|
.setEstimatedOperations(operations)
|
||||||
|
.setInput(TrainedModelInputTests.createRandomInput());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<String> chunkStringWithSize(String str, int chunkSize) {
|
||||||
|
List<String> subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize);
|
||||||
|
for (int i = 0; i < str.length(); i += chunkSize) {
|
||||||
|
subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
|
||||||
|
}
|
||||||
|
return subStrings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NamedXContentRegistry xContentRegistry() {
|
||||||
|
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
||||||
|
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
return new NamedXContentRegistry(namedXContent);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -32,8 +32,11 @@ import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
|
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
|
||||||
|
import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize;
|
||||||
import static org.hamcrest.CoreMatchers.is;
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.not;
|
import static org.hamcrest.Matchers.not;
|
||||||
|
@ -157,8 +160,8 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetTruncatedModelDefinition() throws Exception {
|
public void testGetTruncatedModelDeprecatedDefinition() throws Exception {
|
||||||
String modelId = "test-get-truncated-model-config";
|
String modelId = "test-get-truncated-legacy-model-config";
|
||||||
TrainedModelConfig config = buildTrainedModelConfig(modelId);
|
TrainedModelConfig config = buildTrainedModelConfig(modelId);
|
||||||
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
|
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
|
||||||
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||||
|
@ -196,6 +199,51 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
|
||||||
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testGetTruncatedModelDefinition() throws Exception {
|
||||||
|
String modelId = "test-get-truncated-model-config";
|
||||||
|
TrainedModelConfig config = buildTrainedModelConfig(modelId);
|
||||||
|
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
|
||||||
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
||||||
|
|
||||||
|
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
|
||||||
|
assertThat(putConfigHolder.get(), is(true));
|
||||||
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||||
|
|
||||||
|
List<String> chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3);
|
||||||
|
|
||||||
|
List<TrainedModelDefinitionDoc.Builder> docBuilders = IntStream.range(0, chunks.size())
|
||||||
|
.mapToObj(i -> new TrainedModelDefinitionDoc.Builder()
|
||||||
|
.setDocNum(i)
|
||||||
|
.setCompressedString(chunks.get(i))
|
||||||
|
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
|
||||||
|
.setDefinitionLength(chunks.get(i).length())
|
||||||
|
.setEos(i == chunks.size() - 1)
|
||||||
|
.setModelId(modelId))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
boolean missingEos = randomBoolean();
|
||||||
|
docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false);
|
||||||
|
for (int i = missingEos ? 0 : 1 ; i < docBuilders.size(); ++i) {
|
||||||
|
TrainedModelDefinitionDoc doc = docBuilders.get(i).build();
|
||||||
|
try(XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(),
|
||||||
|
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")))) {
|
||||||
|
AtomicReference<IndexResponse> putDocHolder = new AtomicReference<>();
|
||||||
|
blockingCall(listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
|
||||||
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||||
|
.setSource(xContentBuilder)
|
||||||
|
.setId(TrainedModelDefinitionDoc.docId(modelId, 0))
|
||||||
|
.execute(listener),
|
||||||
|
putDocHolder,
|
||||||
|
exceptionHolder);
|
||||||
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
|
||||||
|
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
|
||||||
|
assertThat(getConfigHolder.get(), is(nullValue()));
|
||||||
|
assertThat(exceptionHolder.get(), is(not(nullValue())));
|
||||||
|
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
||||||
|
}
|
||||||
|
|
||||||
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
|
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
|
||||||
return TrainedModelConfig.builder()
|
return TrainedModelConfig.builder()
|
||||||
.setCreatedBy("ml_test")
|
.setCreatedBy("ml_test")
|
||||||
|
|
|
@ -8,48 +8,29 @@ package org.elasticsearch.xpack.ml.dataframe.process;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
import org.elasticsearch.Version;
|
|
||||||
import org.elasticsearch.action.ActionListener;
|
|
||||||
import org.elasticsearch.action.LatchedActionListener;
|
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
||||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
|
||||||
import org.elasticsearch.license.License;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||||
import org.elasticsearch.xpack.core.security.user.XPackUser;
|
|
||||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
|
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
|
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
|
||||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
|
||||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||||
import org.elasticsearch.xpack.ml.extractor.MultiField;
|
|
||||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||||
|
|
||||||
import java.time.Instant;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import static java.util.stream.Collectors.toList;
|
|
||||||
|
|
||||||
public class AnalyticsResultProcessor {
|
public class AnalyticsResultProcessor {
|
||||||
|
|
||||||
|
@ -70,11 +51,10 @@ public class AnalyticsResultProcessor {
|
||||||
private final DataFrameAnalyticsConfig analytics;
|
private final DataFrameAnalyticsConfig analytics;
|
||||||
private final DataFrameRowsJoiner dataFrameRowsJoiner;
|
private final DataFrameRowsJoiner dataFrameRowsJoiner;
|
||||||
private final StatsHolder statsHolder;
|
private final StatsHolder statsHolder;
|
||||||
private final TrainedModelProvider trainedModelProvider;
|
|
||||||
private final DataFrameAnalyticsAuditor auditor;
|
private final DataFrameAnalyticsAuditor auditor;
|
||||||
private final StatsPersister statsPersister;
|
private final StatsPersister statsPersister;
|
||||||
private final ExtractedFields extractedFields;
|
|
||||||
private final CountDownLatch completionLatch = new CountDownLatch(1);
|
private final CountDownLatch completionLatch = new CountDownLatch(1);
|
||||||
|
private final ChunkedTrainedModelPersister chunkedTrainedModelPersister;
|
||||||
private volatile String failure;
|
private volatile String failure;
|
||||||
private volatile boolean isCancelled;
|
private volatile boolean isCancelled;
|
||||||
|
|
||||||
|
@ -84,10 +64,15 @@ public class AnalyticsResultProcessor {
|
||||||
this.analytics = Objects.requireNonNull(analytics);
|
this.analytics = Objects.requireNonNull(analytics);
|
||||||
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
|
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
|
||||||
this.statsHolder = Objects.requireNonNull(statsHolder);
|
this.statsHolder = Objects.requireNonNull(statsHolder);
|
||||||
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
|
||||||
this.auditor = Objects.requireNonNull(auditor);
|
this.auditor = Objects.requireNonNull(auditor);
|
||||||
this.statsPersister = Objects.requireNonNull(statsPersister);
|
this.statsPersister = Objects.requireNonNull(statsPersister);
|
||||||
this.extractedFields = Objects.requireNonNull(extractedFields);
|
this.chunkedTrainedModelPersister = new ChunkedTrainedModelPersister(
|
||||||
|
trainedModelProvider,
|
||||||
|
analytics,
|
||||||
|
auditor,
|
||||||
|
this::setAndReportFailure,
|
||||||
|
extractedFields
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
|
@ -166,9 +151,13 @@ public class AnalyticsResultProcessor {
|
||||||
phaseProgress.getProgressPercent());
|
phaseProgress.getProgressPercent());
|
||||||
statsHolder.getProgressTracker().updatePhase(phaseProgress);
|
statsHolder.getProgressTracker().updatePhase(phaseProgress);
|
||||||
}
|
}
|
||||||
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
|
ModelSizeInfo modelSize = result.getModelSizeInfo();
|
||||||
if (inferenceModelBuilder != null) {
|
if (modelSize != null) {
|
||||||
createAndIndexInferenceModel(inferenceModelBuilder);
|
chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize);
|
||||||
|
}
|
||||||
|
TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk();
|
||||||
|
if (trainedModelDefinitionChunk != null) {
|
||||||
|
chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk);
|
||||||
}
|
}
|
||||||
MemoryUsage memoryUsage = result.getMemoryUsage();
|
MemoryUsage memoryUsage = result.getMemoryUsage();
|
||||||
if (memoryUsage != null) {
|
if (memoryUsage != null) {
|
||||||
|
@ -191,82 +180,6 @@ public class AnalyticsResultProcessor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) {
|
|
||||||
TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel);
|
|
||||||
CountDownLatch latch = storeTrainedModel(trainedModelConfig);
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (latch.await(30, TimeUnit.SECONDS) == false) {
|
|
||||||
LOGGER.error("[{}] Timed out (30s) waiting for inference model to be stored", analytics.getId());
|
|
||||||
}
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
Thread.currentThread().interrupt();
|
|
||||||
setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) {
|
|
||||||
Instant createTime = Instant.now();
|
|
||||||
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
|
|
||||||
TrainedModelDefinition definition = inferenceModel.build();
|
|
||||||
String dependentVariable = getDependentVariable();
|
|
||||||
List<ExtractedField> fieldNames = extractedFields.getAllFields();
|
|
||||||
List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
|
|
||||||
.map(ExtractedField::getName)
|
|
||||||
.filter(f -> f.equals(dependentVariable) == false)
|
|
||||||
.collect(toList());
|
|
||||||
Map<String, String> defaultFieldMapping = fieldNames.stream()
|
|
||||||
.filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
|
|
||||||
.collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
|
|
||||||
return TrainedModelConfig.builder()
|
|
||||||
.setModelId(modelId)
|
|
||||||
.setCreatedBy(XPackUser.NAME)
|
|
||||||
.setVersion(Version.CURRENT)
|
|
||||||
.setCreateTime(createTime)
|
|
||||||
// NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
|
|
||||||
.setTags(Collections.singletonList(analytics.getId()))
|
|
||||||
.setDescription(analytics.getDescription())
|
|
||||||
.setMetadata(Collections.singletonMap("analytics_config",
|
|
||||||
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
|
|
||||||
.setEstimatedHeapMemory(definition.ramBytesUsed())
|
|
||||||
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
|
|
||||||
.setParsedDefinition(inferenceModel)
|
|
||||||
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
|
|
||||||
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
|
||||||
.setDefaultFieldMap(defaultFieldMapping)
|
|
||||||
.setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields)))
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getDependentVariable() {
|
|
||||||
if (analytics.getAnalysis() instanceof Classification) {
|
|
||||||
return ((Classification)analytics.getAnalysis()).getDependentVariable();
|
|
||||||
}
|
|
||||||
if (analytics.getAnalysis() instanceof Regression) {
|
|
||||||
return ((Regression)analytics.getAnalysis()).getDependentVariable();
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) {
|
|
||||||
CountDownLatch latch = new CountDownLatch(1);
|
|
||||||
ActionListener<Boolean> storeListener = ActionListener.wrap(
|
|
||||||
aBoolean -> {
|
|
||||||
if (aBoolean == false) {
|
|
||||||
LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
|
|
||||||
setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
|
|
||||||
} else {
|
|
||||||
LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
|
|
||||||
auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e,
|
|
||||||
trainedModelConfig.getModelId()))
|
|
||||||
);
|
|
||||||
trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
|
|
||||||
return latch;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void setAndReportFailure(Exception e) {
|
private void setAndReportFailure(Exception e) {
|
||||||
LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
|
LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
|
||||||
failure = "error processing results; " + e.getMessage();
|
failure = "error processing results; " + e.getMessage();
|
||||||
|
|
|
@ -0,0 +1,235 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.ml.dataframe.process;
|
||||||
|
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
|
import org.elasticsearch.Version;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.LatchedActionListener;
|
||||||
|
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
|
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.license.License;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
import org.elasticsearch.xpack.core.security.user.XPackUser;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
|
||||||
|
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||||
|
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||||
|
import org.elasticsearch.xpack.ml.extractor.MultiField;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
|
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static java.util.stream.Collectors.toList;
|
||||||
|
|
||||||
|
public class ChunkedTrainedModelPersister {
|
||||||
|
|
||||||
|
private static final Logger LOGGER = LogManager.getLogger(ChunkedTrainedModelPersister.class);
|
||||||
|
private static final int STORE_TIMEOUT_SEC = 30;
|
||||||
|
private final TrainedModelProvider provider;
|
||||||
|
private final AtomicReference<String> currentModelId;
|
||||||
|
private final DataFrameAnalyticsConfig analytics;
|
||||||
|
private final DataFrameAnalyticsAuditor auditor;
|
||||||
|
private final Consumer<Exception> failureHandler;
|
||||||
|
private final ExtractedFields extractedFields;
|
||||||
|
private final AtomicBoolean readyToStoreNewModel = new AtomicBoolean(true);
|
||||||
|
|
||||||
|
public ChunkedTrainedModelPersister(TrainedModelProvider provider,
|
||||||
|
DataFrameAnalyticsConfig analytics,
|
||||||
|
DataFrameAnalyticsAuditor auditor,
|
||||||
|
Consumer<Exception> failureHandler,
|
||||||
|
ExtractedFields extractedFields) {
|
||||||
|
this.provider = provider;
|
||||||
|
this.currentModelId = new AtomicReference<>("");
|
||||||
|
this.analytics = analytics;
|
||||||
|
this.auditor = auditor;
|
||||||
|
this.failureHandler = failureHandler;
|
||||||
|
this.extractedFields = extractedFields;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
|
||||||
|
if (Strings.isNullOrEmpty(this.currentModelId.get())) {
|
||||||
|
failureHandler.accept(ExceptionsHelper.serverError(
|
||||||
|
"chunked inference model definition is attempting to be stored before trained model configuration"
|
||||||
|
));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TrainedModelDefinitionDoc trainedModelDefinitionDoc = trainedModelDefinitionChunk.createTrainedModelDoc(this.currentModelId.get());
|
||||||
|
|
||||||
|
CountDownLatch latch = storeTrainedModelDoc(trainedModelDefinitionDoc);
|
||||||
|
try {
|
||||||
|
if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
|
||||||
|
LOGGER.error("[{}] Timed out (30s) waiting for chunked inference definition to be stored", analytics.getId());
|
||||||
|
if (trainedModelDefinitionChunk.isEos()) {
|
||||||
|
this.readyToStoreNewModel.set(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
this.readyToStoreNewModel.set(true);
|
||||||
|
failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for chunked inference definition to be stored"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) {
|
||||||
|
if (readyToStoreNewModel.compareAndSet(true, false) == false) {
|
||||||
|
failureHandler.accept(ExceptionsHelper.serverError(
|
||||||
|
"new inference model is attempting to be stored before completion previous model storage"
|
||||||
|
));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModelSize);
|
||||||
|
CountDownLatch latch = storeTrainedModelMetadata(trainedModelConfig);
|
||||||
|
try {
|
||||||
|
if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
|
||||||
|
LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId());
|
||||||
|
}
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
this.readyToStoreNewModel.set(true);
|
||||||
|
failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc) {
|
||||||
|
CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
|
||||||
|
// Latch is attached to this action as it is the last one to execute.
|
||||||
|
ActionListener<RefreshResponse> refreshListener = new LatchedActionListener<>(ActionListener.wrap(
|
||||||
|
refreshed -> {
|
||||||
|
if (refreshed != null) {
|
||||||
|
LOGGER.debug(() -> new ParameterizedMessage(
|
||||||
|
"[{}] refreshed inference index after model store",
|
||||||
|
analytics.getId()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
e -> LOGGER.warn(
|
||||||
|
new ParameterizedMessage("[{}] failed to refresh inference index after model store", analytics.getId()),
|
||||||
|
e)
|
||||||
|
), latch);
|
||||||
|
|
||||||
|
// First, store the model and refresh is necessary
|
||||||
|
ActionListener<Void> storeListener = ActionListener.wrap(
|
||||||
|
r -> {
|
||||||
|
LOGGER.debug(() -> new ParameterizedMessage(
|
||||||
|
"[{}] stored trained model definition chunk [{}] [{}]",
|
||||||
|
analytics.getId(),
|
||||||
|
trainedModelDefinitionDoc.getModelId(),
|
||||||
|
trainedModelDefinitionDoc.getDocNum()));
|
||||||
|
if (trainedModelDefinitionDoc.isEos() == false) {
|
||||||
|
refreshListener.onResponse(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
LOGGER.info(
|
||||||
|
"[{}] finished storing trained model with id [{}]",
|
||||||
|
analytics.getId(),
|
||||||
|
this.currentModelId.get());
|
||||||
|
auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]");
|
||||||
|
this.currentModelId.set("");
|
||||||
|
readyToStoreNewModel.set(true);
|
||||||
|
provider.refreshInferenceIndex(refreshListener);
|
||||||
|
},
|
||||||
|
e -> {
|
||||||
|
this.readyToStoreNewModel.set(true);
|
||||||
|
failureHandler.accept(ExceptionsHelper.serverError(
|
||||||
|
"error storing trained model definition chunk [{}] with id [{}]",
|
||||||
|
e,
|
||||||
|
trainedModelDefinitionDoc.getModelId(),
|
||||||
|
trainedModelDefinitionDoc.getDocNum()));
|
||||||
|
refreshListener.onResponse(null);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, storeListener);
|
||||||
|
return latch;
|
||||||
|
}
|
||||||
|
private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig) {
|
||||||
|
CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
ActionListener<Boolean> storeListener = ActionListener.wrap(
|
||||||
|
aBoolean -> {
|
||||||
|
if (aBoolean == false) {
|
||||||
|
LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId());
|
||||||
|
readyToStoreNewModel.set(true);
|
||||||
|
failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false"));
|
||||||
|
} else {
|
||||||
|
LOGGER.debug("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
e -> {
|
||||||
|
readyToStoreNewModel.set(true);
|
||||||
|
failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]",
|
||||||
|
e,
|
||||||
|
trainedModelConfig.getModelId()));
|
||||||
|
}
|
||||||
|
);
|
||||||
|
provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
|
||||||
|
return latch;
|
||||||
|
}
|
||||||
|
|
||||||
|
private TrainedModelConfig createTrainedModelConfig(ModelSizeInfo modelSize) {
|
||||||
|
Instant createTime = Instant.now();
|
||||||
|
String modelId = analytics.getId() + "-" + createTime.toEpochMilli();
|
||||||
|
currentModelId.set(modelId);
|
||||||
|
List<ExtractedField> fieldNames = extractedFields.getAllFields();
|
||||||
|
String dependentVariable = getDependentVariable();
|
||||||
|
List<String> fieldNamesWithoutDependentVariable = fieldNames.stream()
|
||||||
|
.map(ExtractedField::getName)
|
||||||
|
.filter(f -> f.equals(dependentVariable) == false)
|
||||||
|
.collect(toList());
|
||||||
|
Map<String, String> defaultFieldMapping = fieldNames.stream()
|
||||||
|
.filter(ef -> ef instanceof MultiField && (ef.getName().equals(dependentVariable) == false))
|
||||||
|
.collect(Collectors.toMap(ExtractedField::getParentField, ExtractedField::getName));
|
||||||
|
return TrainedModelConfig.builder()
|
||||||
|
.setModelId(modelId)
|
||||||
|
.setCreatedBy(XPackUser.NAME)
|
||||||
|
.setVersion(Version.CURRENT)
|
||||||
|
.setCreateTime(createTime)
|
||||||
|
// NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
|
||||||
|
.setTags(Collections.singletonList(analytics.getId()))
|
||||||
|
.setDescription(analytics.getDescription())
|
||||||
|
.setMetadata(Collections.singletonMap("analytics_config",
|
||||||
|
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
|
||||||
|
.setEstimatedHeapMemory(modelSize.ramBytesUsed())
|
||||||
|
.setEstimatedOperations(modelSize.numOperations())
|
||||||
|
.setInput(new TrainedModelInput(fieldNamesWithoutDependentVariable))
|
||||||
|
.setLicenseLevel(License.OperationMode.PLATINUM.description())
|
||||||
|
.setDefaultFieldMap(defaultFieldMapping)
|
||||||
|
.setInferenceConfig(analytics.getAnalysis().inferenceConfig(new AnalysisFieldInfo(extractedFields)))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getDependentVariable() {
|
||||||
|
if (analytics.getAnalysis() instanceof Classification) {
|
||||||
|
return ((Classification)analytics.getAnalysis()).getDependentVariable();
|
||||||
|
}
|
||||||
|
if (analytics.getAnalysis() instanceof Regression) {
|
||||||
|
return ((Regression)analytics.getAnalysis()).getDependentVariable();
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -8,31 +8,27 @@ package org.elasticsearch.xpack.ml.dataframe.process.results;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.ParseField;
|
import org.elasticsearch.common.ParseField;
|
||||||
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
import org.elasticsearch.common.xcontent.ToXContent;
|
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
|
||||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||||
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
|
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
|
|
||||||
|
|
||||||
public class AnalyticsResult implements ToXContentObject {
|
public class AnalyticsResult implements ToXContentObject {
|
||||||
|
|
||||||
public static final ParseField TYPE = new ParseField("analytics_result");
|
public static final ParseField TYPE = new ParseField("analytics_result");
|
||||||
|
|
||||||
private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
|
private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
|
||||||
private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
|
|
||||||
private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info");
|
private static final ParseField MODEL_SIZE_INFO = new ParseField("model_size_info");
|
||||||
|
private static final ParseField COMPRESSED_INFERENCE_MODEL = new ParseField("compressed_inference_model");
|
||||||
private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
|
private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
|
||||||
private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
|
private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
|
||||||
private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
|
private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
|
||||||
|
@ -42,53 +38,50 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
a -> new AnalyticsResult(
|
a -> new AnalyticsResult(
|
||||||
(RowResults) a[0],
|
(RowResults) a[0],
|
||||||
(PhaseProgress) a[1],
|
(PhaseProgress) a[1],
|
||||||
(TrainedModelDefinition.Builder) a[2],
|
(MemoryUsage) a[2],
|
||||||
(MemoryUsage) a[3],
|
(OutlierDetectionStats) a[3],
|
||||||
(OutlierDetectionStats) a[4],
|
(ClassificationStats) a[4],
|
||||||
(ClassificationStats) a[5],
|
(RegressionStats) a[5],
|
||||||
(RegressionStats) a[6],
|
(ModelSizeInfo) a[6],
|
||||||
(ModelSizeInfo) a[7]
|
(TrainedModelDefinitionChunk) a[7]
|
||||||
));
|
));
|
||||||
|
|
||||||
static {
|
static {
|
||||||
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
|
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
|
||||||
PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS);
|
PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS);
|
||||||
// TODO change back to STRICT_PARSER once native side is aligned
|
|
||||||
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
|
|
||||||
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE);
|
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.STRICT_PARSER, ANALYTICS_MEMORY_USAGE);
|
||||||
PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
|
PARSER.declareObject(optionalConstructorArg(), OutlierDetectionStats.STRICT_PARSER, OUTLIER_DETECTION_STATS);
|
||||||
PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
|
PARSER.declareObject(optionalConstructorArg(), ClassificationStats.STRICT_PARSER, CLASSIFICATION_STATS);
|
||||||
PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
|
PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
|
||||||
PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO);
|
PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO);
|
||||||
|
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinitionChunk.PARSER, COMPRESSED_INFERENCE_MODEL);
|
||||||
}
|
}
|
||||||
|
|
||||||
private final RowResults rowResults;
|
private final RowResults rowResults;
|
||||||
private final PhaseProgress phaseProgress;
|
private final PhaseProgress phaseProgress;
|
||||||
private final TrainedModelDefinition.Builder inferenceModelBuilder;
|
|
||||||
private final TrainedModelDefinition inferenceModel;
|
|
||||||
private final MemoryUsage memoryUsage;
|
private final MemoryUsage memoryUsage;
|
||||||
private final OutlierDetectionStats outlierDetectionStats;
|
private final OutlierDetectionStats outlierDetectionStats;
|
||||||
private final ClassificationStats classificationStats;
|
private final ClassificationStats classificationStats;
|
||||||
private final RegressionStats regressionStats;
|
private final RegressionStats regressionStats;
|
||||||
private final ModelSizeInfo modelSizeInfo;
|
private final ModelSizeInfo modelSizeInfo;
|
||||||
|
private final TrainedModelDefinitionChunk trainedModelDefinitionChunk;
|
||||||
|
|
||||||
public AnalyticsResult(@Nullable RowResults rowResults,
|
public AnalyticsResult(@Nullable RowResults rowResults,
|
||||||
@Nullable PhaseProgress phaseProgress,
|
@Nullable PhaseProgress phaseProgress,
|
||||||
@Nullable TrainedModelDefinition.Builder inferenceModelBuilder,
|
|
||||||
@Nullable MemoryUsage memoryUsage,
|
@Nullable MemoryUsage memoryUsage,
|
||||||
@Nullable OutlierDetectionStats outlierDetectionStats,
|
@Nullable OutlierDetectionStats outlierDetectionStats,
|
||||||
@Nullable ClassificationStats classificationStats,
|
@Nullable ClassificationStats classificationStats,
|
||||||
@Nullable RegressionStats regressionStats,
|
@Nullable RegressionStats regressionStats,
|
||||||
@Nullable ModelSizeInfo modelSizeInfo) {
|
@Nullable ModelSizeInfo modelSizeInfo,
|
||||||
|
@Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
|
||||||
this.rowResults = rowResults;
|
this.rowResults = rowResults;
|
||||||
this.phaseProgress = phaseProgress;
|
this.phaseProgress = phaseProgress;
|
||||||
this.inferenceModelBuilder = inferenceModelBuilder;
|
|
||||||
this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
|
|
||||||
this.memoryUsage = memoryUsage;
|
this.memoryUsage = memoryUsage;
|
||||||
this.outlierDetectionStats = outlierDetectionStats;
|
this.outlierDetectionStats = outlierDetectionStats;
|
||||||
this.classificationStats = classificationStats;
|
this.classificationStats = classificationStats;
|
||||||
this.regressionStats = regressionStats;
|
this.regressionStats = regressionStats;
|
||||||
this.modelSizeInfo = modelSizeInfo;
|
this.modelSizeInfo = modelSizeInfo;
|
||||||
|
this.trainedModelDefinitionChunk = trainedModelDefinitionChunk;
|
||||||
}
|
}
|
||||||
|
|
||||||
public RowResults getRowResults() {
|
public RowResults getRowResults() {
|
||||||
|
@ -99,10 +92,6 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
return phaseProgress;
|
return phaseProgress;
|
||||||
}
|
}
|
||||||
|
|
||||||
public TrainedModelDefinition.Builder getInferenceModelBuilder() {
|
|
||||||
return inferenceModelBuilder;
|
|
||||||
}
|
|
||||||
|
|
||||||
public MemoryUsage getMemoryUsage() {
|
public MemoryUsage getMemoryUsage() {
|
||||||
return memoryUsage;
|
return memoryUsage;
|
||||||
}
|
}
|
||||||
|
@ -123,6 +112,10 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
return modelSizeInfo;
|
return modelSizeInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public TrainedModelDefinitionChunk getTrainedModelDefinitionChunk() {
|
||||||
|
return trainedModelDefinitionChunk;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
|
@ -132,11 +125,6 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
if (phaseProgress != null) {
|
if (phaseProgress != null) {
|
||||||
builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress);
|
builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress);
|
||||||
}
|
}
|
||||||
if (inferenceModel != null) {
|
|
||||||
builder.field(INFERENCE_MODEL.getPreferredName(),
|
|
||||||
inferenceModel,
|
|
||||||
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")));
|
|
||||||
}
|
|
||||||
if (memoryUsage != null) {
|
if (memoryUsage != null) {
|
||||||
builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params);
|
builder.field(ANALYTICS_MEMORY_USAGE.getPreferredName(), memoryUsage, params);
|
||||||
}
|
}
|
||||||
|
@ -152,6 +140,9 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
if (modelSizeInfo != null) {
|
if (modelSizeInfo != null) {
|
||||||
builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo);
|
builder.field(MODEL_SIZE_INFO.getPreferredName(), modelSizeInfo);
|
||||||
}
|
}
|
||||||
|
if (trainedModelDefinitionChunk != null) {
|
||||||
|
builder.field(COMPRESSED_INFERENCE_MODEL.getPreferredName(), trainedModelDefinitionChunk);
|
||||||
|
}
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
@ -168,17 +159,17 @@ public class AnalyticsResult implements ToXContentObject {
|
||||||
AnalyticsResult that = (AnalyticsResult) other;
|
AnalyticsResult that = (AnalyticsResult) other;
|
||||||
return Objects.equals(rowResults, that.rowResults)
|
return Objects.equals(rowResults, that.rowResults)
|
||||||
&& Objects.equals(phaseProgress, that.phaseProgress)
|
&& Objects.equals(phaseProgress, that.phaseProgress)
|
||||||
&& Objects.equals(inferenceModel, that.inferenceModel)
|
|
||||||
&& Objects.equals(memoryUsage, that.memoryUsage)
|
&& Objects.equals(memoryUsage, that.memoryUsage)
|
||||||
&& Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
|
&& Objects.equals(outlierDetectionStats, that.outlierDetectionStats)
|
||||||
&& Objects.equals(classificationStats, that.classificationStats)
|
&& Objects.equals(classificationStats, that.classificationStats)
|
||||||
&& Objects.equals(modelSizeInfo, that.modelSizeInfo)
|
&& Objects.equals(modelSizeInfo, that.modelSizeInfo)
|
||||||
|
&& Objects.equals(trainedModelDefinitionChunk, that.trainedModelDefinitionChunk)
|
||||||
&& Objects.equals(regressionStats, that.regressionStats);
|
&& Objects.equals(regressionStats, that.regressionStats);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
|
return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats,
|
||||||
regressionStats);
|
regressionStats, modelSizeInfo, trainedModelDefinitionChunk);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.ml.dataframe.process.results;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
|
|
||||||
|
public class TrainedModelDefinitionChunk implements ToXContentObject {
|
||||||
|
|
||||||
|
private static final ParseField DEFINITION = new ParseField("definition");
|
||||||
|
private static final ParseField DOC_NUM = new ParseField("doc_num");
|
||||||
|
private static final ParseField EOS = new ParseField("eos");
|
||||||
|
|
||||||
|
public static final ConstructingObjectParser<TrainedModelDefinitionChunk, Void> PARSER = new ConstructingObjectParser<>(
|
||||||
|
"chunked_trained_model_definition",
|
||||||
|
a -> new TrainedModelDefinitionChunk((String) a[0], (Integer) a[1], (Boolean) a[2]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareString(constructorArg(), DEFINITION);
|
||||||
|
PARSER.declareInt(constructorArg(), DOC_NUM);
|
||||||
|
PARSER.declareBoolean(optionalConstructorArg(), EOS);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final String definition;
|
||||||
|
private final int docNum;
|
||||||
|
private final Boolean eos;
|
||||||
|
|
||||||
|
public TrainedModelDefinitionChunk(String definition, int docNum, Boolean eos) {
|
||||||
|
this.definition = definition;
|
||||||
|
this.docNum = docNum;
|
||||||
|
this.eos = eos;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainedModelDefinitionDoc createTrainedModelDoc(String modelId) {
|
||||||
|
return new TrainedModelDefinitionDoc.Builder()
|
||||||
|
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
|
||||||
|
.setModelId(modelId)
|
||||||
|
.setDefinitionLength(definition.length())
|
||||||
|
.setDocNum(docNum)
|
||||||
|
.setCompressedString(definition)
|
||||||
|
.setEos(isEos())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isEos() {
|
||||||
|
return eos != null && eos;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(DEFINITION.getPreferredName(), definition);
|
||||||
|
builder.field(DOC_NUM.getPreferredName(), docNum);
|
||||||
|
if (eos != null) {
|
||||||
|
builder.field(EOS.getPreferredName(), eos);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
TrainedModelDefinitionChunk that = (TrainedModelDefinitionChunk) o;
|
||||||
|
return docNum == that.docNum
|
||||||
|
&& Objects.equals(definition, that.definition)
|
||||||
|
&& Objects.equals(eos, that.eos);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(definition, docNum, eos);
|
||||||
|
}
|
||||||
|
}
|
|
@ -33,6 +33,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version");
|
public static final ParseField COMPRESSION_VERSION = new ParseField("compression_version");
|
||||||
public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length");
|
public static final ParseField TOTAL_DEFINITION_LENGTH = new ParseField("total_definition_length");
|
||||||
public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length");
|
public static final ParseField DEFINITION_LENGTH = new ParseField("definition_length");
|
||||||
|
public static final ParseField EOS = new ParseField("eos");
|
||||||
|
|
||||||
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
|
||||||
public static final ObjectParser<TrainedModelDefinitionDoc.Builder, Void> LENIENT_PARSER = createParser(true);
|
public static final ObjectParser<TrainedModelDefinitionDoc.Builder, Void> LENIENT_PARSER = createParser(true);
|
||||||
|
@ -48,6 +49,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION);
|
parser.declareInt(TrainedModelDefinitionDoc.Builder::setCompressionVersion, COMPRESSION_VERSION);
|
||||||
parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH);
|
parser.declareLong(TrainedModelDefinitionDoc.Builder::setDefinitionLength, DEFINITION_LENGTH);
|
||||||
parser.declareLong(TrainedModelDefinitionDoc.Builder::setTotalDefinitionLength, TOTAL_DEFINITION_LENGTH);
|
parser.declareLong(TrainedModelDefinitionDoc.Builder::setTotalDefinitionLength, TOTAL_DEFINITION_LENGTH);
|
||||||
|
parser.declareBoolean(TrainedModelDefinitionDoc.Builder::setEos, EOS);
|
||||||
return parser;
|
return parser;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,23 +65,26 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
private final String compressedString;
|
private final String compressedString;
|
||||||
private final String modelId;
|
private final String modelId;
|
||||||
private final int docNum;
|
private final int docNum;
|
||||||
private final long totalDefinitionLength;
|
// for BWC
|
||||||
|
private final Long totalDefinitionLength;
|
||||||
private final long definitionLength;
|
private final long definitionLength;
|
||||||
private final int compressionVersion;
|
private final int compressionVersion;
|
||||||
|
private final boolean eos;
|
||||||
|
|
||||||
private TrainedModelDefinitionDoc(String compressedString,
|
private TrainedModelDefinitionDoc(String compressedString,
|
||||||
String modelId,
|
String modelId,
|
||||||
int docNum,
|
int docNum,
|
||||||
long totalDefinitionLength,
|
Long totalDefinitionLength,
|
||||||
long definitionLength,
|
long definitionLength,
|
||||||
int compressionVersion) {
|
int compressionVersion,
|
||||||
|
boolean eos) {
|
||||||
this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION);
|
this.compressedString = ExceptionsHelper.requireNonNull(compressedString, DEFINITION);
|
||||||
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
|
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
|
||||||
if (docNum < 0) {
|
if (docNum < 0) {
|
||||||
throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0");
|
throw new IllegalArgumentException("[doc_num] must be greater than or equal to 0");
|
||||||
}
|
}
|
||||||
this.docNum = docNum;
|
this.docNum = docNum;
|
||||||
if (totalDefinitionLength <= 0L) {
|
if (totalDefinitionLength != null && totalDefinitionLength <= 0L) {
|
||||||
throw new IllegalArgumentException("[total_definition_length] must be greater than 0");
|
throw new IllegalArgumentException("[total_definition_length] must be greater than 0");
|
||||||
}
|
}
|
||||||
this.totalDefinitionLength = totalDefinitionLength;
|
this.totalDefinitionLength = totalDefinitionLength;
|
||||||
|
@ -88,6 +93,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
}
|
}
|
||||||
this.definitionLength = definitionLength;
|
this.definitionLength = definitionLength;
|
||||||
this.compressionVersion = compressionVersion;
|
this.compressionVersion = compressionVersion;
|
||||||
|
this.eos = eos;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getCompressedString() {
|
public String getCompressedString() {
|
||||||
|
@ -102,7 +108,7 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
return docNum;
|
return docNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
public long getTotalDefinitionLength() {
|
public Long getTotalDefinitionLength() {
|
||||||
return totalDefinitionLength;
|
return totalDefinitionLength;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,16 +120,24 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
return compressionVersion;
|
return compressionVersion;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isEos() {
|
||||||
|
return eos;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getDocId() {
|
||||||
|
return docId(modelId, docNum);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
|
||||||
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
|
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
|
||||||
builder.field(DOC_NUM.getPreferredName(), docNum);
|
builder.field(DOC_NUM.getPreferredName(), docNum);
|
||||||
builder.field(TOTAL_DEFINITION_LENGTH.getPreferredName(), totalDefinitionLength);
|
|
||||||
builder.field(DEFINITION_LENGTH.getPreferredName(), definitionLength);
|
builder.field(DEFINITION_LENGTH.getPreferredName(), definitionLength);
|
||||||
builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion);
|
builder.field(COMPRESSION_VERSION.getPreferredName(), compressionVersion);
|
||||||
builder.field(DEFINITION.getPreferredName(), compressedString);
|
builder.field(DEFINITION.getPreferredName(), compressedString);
|
||||||
|
builder.field(EOS.getPreferredName(), eos);
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
@ -143,12 +157,13 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
Objects.equals(definitionLength, that.definitionLength) &&
|
Objects.equals(definitionLength, that.definitionLength) &&
|
||||||
Objects.equals(totalDefinitionLength, that.totalDefinitionLength) &&
|
Objects.equals(totalDefinitionLength, that.totalDefinitionLength) &&
|
||||||
Objects.equals(compressionVersion, that.compressionVersion) &&
|
Objects.equals(compressionVersion, that.compressionVersion) &&
|
||||||
|
Objects.equals(eos, that.eos) &&
|
||||||
Objects.equals(compressedString, that.compressedString);
|
Objects.equals(compressedString, that.compressedString);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(modelId, docNum, totalDefinitionLength, definitionLength, compressionVersion, compressedString);
|
return Objects.hash(modelId, docNum, definitionLength, totalDefinitionLength, compressionVersion, compressedString, eos);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
|
@ -156,9 +171,10 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
private String modelId;
|
private String modelId;
|
||||||
private String compressedString;
|
private String compressedString;
|
||||||
private int docNum;
|
private int docNum;
|
||||||
private long totalDefinitionLength;
|
private Long totalDefinitionLength;
|
||||||
private long definitionLength;
|
private long definitionLength;
|
||||||
private int compressionVersion;
|
private int compressionVersion;
|
||||||
|
private boolean eos;
|
||||||
|
|
||||||
public Builder setModelId(String modelId) {
|
public Builder setModelId(String modelId) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
|
@ -190,6 +206,11 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setEos(boolean eos) {
|
||||||
|
this.eos = eos;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public TrainedModelDefinitionDoc build() {
|
public TrainedModelDefinitionDoc build() {
|
||||||
return new TrainedModelDefinitionDoc(
|
return new TrainedModelDefinitionDoc(
|
||||||
this.compressedString,
|
this.compressedString,
|
||||||
|
@ -197,7 +218,8 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
|
||||||
this.docNum,
|
this.docNum,
|
||||||
this.totalDefinitionLength,
|
this.totalDefinitionLength,
|
||||||
this.definitionLength,
|
this.definitionLength,
|
||||||
this.compressionVersion);
|
this.compressionVersion,
|
||||||
|
this.eos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,10 +14,14 @@ import org.elasticsearch.ResourceAlreadyExistsException;
|
||||||
import org.elasticsearch.ResourceNotFoundException;
|
import org.elasticsearch.ResourceNotFoundException;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.action.DocWriteRequest;
|
import org.elasticsearch.action.DocWriteRequest;
|
||||||
|
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
|
||||||
|
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||||
|
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
|
||||||
import org.elasticsearch.action.bulk.BulkAction;
|
import org.elasticsearch.action.bulk.BulkAction;
|
||||||
import org.elasticsearch.action.bulk.BulkItemResponse;
|
import org.elasticsearch.action.bulk.BulkItemResponse;
|
||||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||||
import org.elasticsearch.action.bulk.BulkResponse;
|
import org.elasticsearch.action.bulk.BulkResponse;
|
||||||
|
import org.elasticsearch.action.index.IndexAction;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.search.MultiSearchAction;
|
import org.elasticsearch.action.search.MultiSearchAction;
|
||||||
import org.elasticsearch.action.search.MultiSearchRequest;
|
import org.elasticsearch.action.search.MultiSearchRequest;
|
||||||
|
@ -143,6 +147,74 @@ public class TrainedModelProvider {
|
||||||
storeTrainedModelAndDefinition(trainedModelConfig, listener);
|
storeTrainedModelAndDefinition(trainedModelConfig, listener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig,
|
||||||
|
ActionListener<Boolean> listener) {
|
||||||
|
if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) {
|
||||||
|
listener.onFailure(new ResourceAlreadyExistsException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
assert trainedModelConfig.getModelDefinition() == null;
|
||||||
|
|
||||||
|
executeAsyncWithOrigin(client,
|
||||||
|
ML_ORIGIN,
|
||||||
|
IndexAction.INSTANCE,
|
||||||
|
createRequest(trainedModelConfig.getModelId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelConfig),
|
||||||
|
ActionListener.wrap(
|
||||||
|
indexResponse -> listener.onResponse(true),
|
||||||
|
e -> {
|
||||||
|
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
|
||||||
|
listener.onFailure(new ResourceAlreadyExistsException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
|
||||||
|
} else {
|
||||||
|
listener.onFailure(
|
||||||
|
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
e,
|
||||||
|
trainedModelConfig.getModelId()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void storeTrainedModelDefinitionDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc, ActionListener<Void> listener) {
|
||||||
|
if (MODELS_STORED_AS_RESOURCE.contains(trainedModelDefinitionDoc.getModelId())) {
|
||||||
|
listener.onFailure(new ResourceAlreadyExistsException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelDefinitionDoc.getModelId())));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
executeAsyncWithOrigin(client,
|
||||||
|
ML_ORIGIN,
|
||||||
|
IndexAction.INSTANCE,
|
||||||
|
createRequest(trainedModelDefinitionDoc.getDocId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelDefinitionDoc),
|
||||||
|
ActionListener.wrap(
|
||||||
|
indexResponse -> listener.onResponse(null),
|
||||||
|
e -> {
|
||||||
|
if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
|
||||||
|
listener.onFailure(new ResourceAlreadyExistsException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_DOC_EXISTS,
|
||||||
|
trainedModelDefinitionDoc.getModelId(),
|
||||||
|
trainedModelDefinitionDoc.getDocNum())));
|
||||||
|
} else {
|
||||||
|
listener.onFailure(
|
||||||
|
new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL,
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
e,
|
||||||
|
trainedModelDefinitionDoc.getModelId()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void refreshInferenceIndex(ActionListener<RefreshResponse> listener) {
|
||||||
|
executeAsyncWithOrigin(client,
|
||||||
|
ML_ORIGIN,
|
||||||
|
RefreshAction.INSTANCE,
|
||||||
|
new RefreshRequest(InferenceIndexConstants.INDEX_PATTERN),
|
||||||
|
listener);
|
||||||
|
}
|
||||||
|
|
||||||
private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig,
|
private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfig,
|
||||||
ActionListener<Boolean> listener) {
|
ActionListener<Boolean> listener) {
|
||||||
|
|
||||||
|
@ -165,7 +237,8 @@ public class TrainedModelProvider {
|
||||||
.setCompressedString(chunkedStrings.get(i))
|
.setCompressedString(chunkedStrings.get(i))
|
||||||
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
|
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
|
||||||
.setDefinitionLength(chunkedStrings.get(i).length())
|
.setDefinitionLength(chunkedStrings.get(i).length())
|
||||||
.setTotalDefinitionLength(compressedString.length())
|
// If it is the last doc, it is the EOS
|
||||||
|
.setEos(i == chunkedStrings.size() - 1)
|
||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
|
@ -265,6 +338,9 @@ public class TrainedModelProvider {
|
||||||
.unmappedType("long"))
|
.unmappedType("long"))
|
||||||
.request();
|
.request();
|
||||||
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
|
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
|
||||||
|
// TODO how could we stream in the model definition WHILE parsing it?
|
||||||
|
// This would reduce the overall memory usage as we won't have to load the whole compressed string
|
||||||
|
// XContentParser supports streams.
|
||||||
searchResponse -> {
|
searchResponse -> {
|
||||||
if (searchResponse.getHits().getHits().length == 0) {
|
if (searchResponse.getHits().getHits().length == 0) {
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
listener.onFailure(new ResourceNotFoundException(
|
||||||
|
@ -274,19 +350,16 @@ public class TrainedModelProvider {
|
||||||
List<TrainedModelDefinitionDoc> docs = handleHits(searchResponse.getHits().getHits(),
|
List<TrainedModelDefinitionDoc> docs = handleHits(searchResponse.getHits().getHits(),
|
||||||
modelId,
|
modelId,
|
||||||
this::parseModelDefinitionDocLenientlyFromSource);
|
this::parseModelDefinitionDocLenientlyFromSource);
|
||||||
String compressedString = docs.stream()
|
try {
|
||||||
.map(TrainedModelDefinitionDoc::getCompressedString)
|
String compressedString = getDefinitionFromDocs(docs, modelId);
|
||||||
.collect(Collectors.joining());
|
InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
|
||||||
if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
|
compressedString,
|
||||||
listener.onFailure(ExceptionsHelper.serverError(
|
InferenceDefinition::fromXContent,
|
||||||
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
xContentRegistry);
|
||||||
return;
|
listener.onResponse(inferenceDefinition);
|
||||||
|
} catch (ElasticsearchException elasticsearchException) {
|
||||||
|
listener.onFailure(elasticsearchException);
|
||||||
}
|
}
|
||||||
InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
|
|
||||||
compressedString,
|
|
||||||
InferenceDefinition::fromXContent,
|
|
||||||
xContentRegistry);
|
|
||||||
listener.onResponse(inferenceDefinition);
|
|
||||||
},
|
},
|
||||||
e -> {
|
e -> {
|
||||||
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
||||||
|
@ -361,15 +434,14 @@ public class TrainedModelProvider {
|
||||||
List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
|
List<TrainedModelDefinitionDoc> docs = handleSearchItems(multiSearchResponse.getResponses()[1],
|
||||||
modelId,
|
modelId,
|
||||||
this::parseModelDefinitionDocLenientlyFromSource);
|
this::parseModelDefinitionDocLenientlyFromSource);
|
||||||
String compressedString = docs.stream()
|
try {
|
||||||
.map(TrainedModelDefinitionDoc::getCompressedString)
|
String compressedString = getDefinitionFromDocs(docs, modelId);
|
||||||
.collect(Collectors.joining());
|
builder.setDefinitionFromString(compressedString);
|
||||||
if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
|
} catch (ElasticsearchException elasticsearchException) {
|
||||||
listener.onFailure(ExceptionsHelper.serverError(
|
listener.onFailure(elasticsearchException);
|
||||||
Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
builder.setDefinitionFromString(compressedString);
|
|
||||||
} catch (ResourceNotFoundException ex) {
|
} catch (ResourceNotFoundException ex) {
|
||||||
listener.onFailure(new ResourceNotFoundException(
|
listener.onFailure(new ResourceNotFoundException(
|
||||||
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||||
|
@ -806,6 +878,26 @@ public class TrainedModelProvider {
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static String getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId) throws ElasticsearchException {
|
||||||
|
String compressedString = docs.stream()
|
||||||
|
.map(TrainedModelDefinitionDoc::getCompressedString)
|
||||||
|
.collect(Collectors.joining());
|
||||||
|
// BWC for when we tracked the total definition length
|
||||||
|
// TODO: remove in 9
|
||||||
|
if (docs.get(0).getTotalDefinitionLength() != null) {
|
||||||
|
if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) {
|
||||||
|
throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
TrainedModelDefinitionDoc lastDoc = docs.get(docs.size() - 1);
|
||||||
|
// Either we are missing the last doc, or some previous doc
|
||||||
|
if(lastDoc.isEos() == false || lastDoc.getDocNum() != docs.size() - 1) {
|
||||||
|
throw ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return compressedString;
|
||||||
|
}
|
||||||
|
|
||||||
static List<String> chunkStringWithSize(String str, int chunkSize) {
|
static List<String> chunkStringWithSize(String str, int chunkSize) {
|
||||||
List<String> subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize));
|
List<String> subStrings = new ArrayList<>((int)Math.ceil(str.length()/(double)chunkSize));
|
||||||
for (int i = 0; i < str.length();i += chunkSize) {
|
for (int i = 0; i < str.length();i += chunkSize) {
|
||||||
|
@ -836,14 +928,18 @@ public class TrainedModelProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private IndexRequest createRequest(String docId, String index, ToXContentObject body) {
|
||||||
|
return createRequest(new IndexRequest(index), docId, body);
|
||||||
|
}
|
||||||
|
|
||||||
private IndexRequest createRequest(String docId, ToXContentObject body) {
|
private IndexRequest createRequest(String docId, ToXContentObject body) {
|
||||||
|
return createRequest(new IndexRequest(), docId, body);
|
||||||
|
}
|
||||||
|
|
||||||
|
private IndexRequest createRequest(IndexRequest request, String docId, ToXContentObject body) {
|
||||||
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||||
XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
|
XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS);
|
||||||
|
return request.opType(DocWriteRequest.OpType.CREATE).id(docId).source(source);
|
||||||
return new IndexRequest()
|
|
||||||
.opType(DocWriteRequest.OpType.CREATE)
|
|
||||||
.id(docId)
|
|
||||||
.source(source);
|
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
// This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again
|
// This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again
|
||||||
// that is not the users fault. We did something wrong and should throw.
|
// that is not the users fault. We did something wrong and should throw.
|
||||||
|
|
|
@ -5,32 +5,20 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.ml.dataframe.process;
|
package org.elasticsearch.xpack.ml.dataframe.process;
|
||||||
|
|
||||||
import org.elasticsearch.Version;
|
|
||||||
import org.elasticsearch.action.ActionListener;
|
|
||||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
||||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
|
||||||
import org.elasticsearch.license.License;
|
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
|
||||||
import org.elasticsearch.xpack.core.security.user.XPackUser;
|
|
||||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
|
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
|
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
|
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
|
||||||
import org.elasticsearch.xpack.ml.extractor.DocValueField;
|
|
||||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||||
import org.elasticsearch.xpack.ml.extractor.MultiField;
|
|
||||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -38,20 +26,14 @@ import org.mockito.ArgumentCaptor;
|
||||||
import org.mockito.InOrder;
|
import org.mockito.InOrder;
|
||||||
import org.mockito.Mockito;
|
import org.mockito.Mockito;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.contains;
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.hasKey;
|
|
||||||
import static org.hamcrest.Matchers.startsWith;
|
|
||||||
import static org.mockito.Matchers.any;
|
import static org.mockito.Matchers.any;
|
||||||
import static org.mockito.Matchers.eq;
|
import static org.mockito.Matchers.eq;
|
||||||
import static org.mockito.Mockito.doAnswer;
|
|
||||||
import static org.mockito.Mockito.doThrow;
|
import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
|
@ -156,90 +138,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||||
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
|
|
||||||
givenDataFrameRows(0);
|
|
||||||
|
|
||||||
doAnswer(invocationOnMock -> {
|
|
||||||
ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
|
|
||||||
storeListener.onResponse(true);
|
|
||||||
return null;
|
|
||||||
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
|
|
||||||
|
|
||||||
List<ExtractedField> extractedFieldList = new ArrayList<>(3);
|
|
||||||
extractedFieldList.add(new DocValueField("foo", Collections.emptySet()));
|
|
||||||
extractedFieldList.add(new MultiField("bar", new DocValueField("bar.keyword", Collections.emptySet())));
|
|
||||||
extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
|
|
||||||
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
|
|
||||||
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
|
|
||||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
|
|
||||||
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
|
|
||||||
|
|
||||||
resultProcessor.process(process);
|
|
||||||
resultProcessor.awaitForCompletion();
|
|
||||||
|
|
||||||
ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
|
|
||||||
verify(trainedModelProvider).storeTrainedModel(storedModelCaptor.capture(), any(ActionListener.class));
|
|
||||||
|
|
||||||
TrainedModelConfig storedModel = storedModelCaptor.getValue();
|
|
||||||
assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
|
|
||||||
assertThat(storedModel.getModelId(), containsString(JOB_ID));
|
|
||||||
assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
|
|
||||||
assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME));
|
|
||||||
assertThat(storedModel.getTags(), contains(JOB_ID));
|
|
||||||
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
|
|
||||||
assertThat(storedModel.getModelDefinition(), equalTo(inferenceModel.build()));
|
|
||||||
assertThat(storedModel.getDefaultFieldMap(), equalTo(Collections.singletonMap("bar", "bar.keyword")));
|
|
||||||
assertThat(storedModel.getInput().getFieldNames(), equalTo(Arrays.asList("bar.keyword", "baz")));
|
|
||||||
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
|
|
||||||
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
|
|
||||||
if (targetType.equals(TargetType.CLASSIFICATION)) {
|
|
||||||
assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification"));
|
|
||||||
} else {
|
|
||||||
assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression"));
|
|
||||||
}
|
|
||||||
Map<String, Object> metadata = storedModel.getMetadata();
|
|
||||||
assertThat(metadata.size(), equalTo(1));
|
|
||||||
assertThat(metadata, hasKey("analytics_config"));
|
|
||||||
Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
|
|
||||||
true);
|
|
||||||
assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));
|
|
||||||
|
|
||||||
ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
|
|
||||||
verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
|
|
||||||
assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
|
|
||||||
Mockito.verifyNoMoreInteractions(auditor);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public void testProcess_GivenInferenceModelFailedToStore() {
|
|
||||||
givenDataFrameRows(0);
|
|
||||||
|
|
||||||
doAnswer(invocationOnMock -> {
|
|
||||||
ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
|
|
||||||
storeListener.onFailure(new RuntimeException("some failure"));
|
|
||||||
return null;
|
|
||||||
}).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class));
|
|
||||||
|
|
||||||
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
|
|
||||||
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
|
|
||||||
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null, null)));
|
|
||||||
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
||||||
|
|
||||||
resultProcessor.process(process);
|
|
||||||
resultProcessor.awaitForCompletion();
|
|
||||||
|
|
||||||
// This test verifies the processor knows how to handle a failure on storing the model and completes normally
|
|
||||||
ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
|
|
||||||
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
|
|
||||||
assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID));
|
|
||||||
Mockito.verifyNoMoreInteractions(auditor);
|
|
||||||
|
|
||||||
assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
|
|
||||||
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void givenProcessResults(List<AnalyticsResult> results) {
|
private void givenProcessResults(List<AnalyticsResult> results) {
|
||||||
when(process.readAnalyticsResults()).thenReturn(results.iterator());
|
when(process.readAnalyticsResults()).thenReturn(results.iterator());
|
||||||
}
|
}
|
||||||
|
@ -256,7 +154,6 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private AnalyticsResultProcessor createResultProcessor(List<ExtractedField> fieldNames) {
|
private AnalyticsResultProcessor createResultProcessor(List<ExtractedField> fieldNames) {
|
||||||
|
|
||||||
return new AnalyticsResultProcessor(analyticsConfig,
|
return new AnalyticsResultProcessor(analyticsConfig,
|
||||||
dataFrameRowsJoiner,
|
dataFrameRowsJoiner,
|
||||||
statsHolder,
|
statsHolder,
|
||||||
|
|
|
@ -0,0 +1,150 @@
|
||||||
|
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.dataframe.process;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

/**
 * Tests for {@link ChunkedTrainedModelPersister}: the model size info document must
 * produce a stored {@link TrainedModelConfig}, and each subsequent definition chunk
 * must be stored as a {@link TrainedModelDefinitionDoc} carrying the same model id
 * and an increasing doc number.
 */
public class ChunkedTrainedModelPersisterTests extends ESTestCase {

    // Fixed: previous value ("analytics-result-processor-tests") was a copy-paste
    // leftover from AnalyticsResultProcessorTests and misnamed this test class.
    private static final String JOB_ID = "chunked-trained-model-persister-tests";
    private static final String JOB_DESCRIPTION = "This describes the job of these tests";

    private TrainedModelProvider trainedModelProvider;
    private DataFrameAnalyticsAuditor auditor;

    @Before
    public void setUpMocks() {
        trainedModelProvider = mock(TrainedModelProvider.class);
        auditor = mock(DataFrameAnalyticsAuditor.class);
    }

    @SuppressWarnings("unchecked")
    public void testPersistAllDocs() {
        // Analysis type is randomized so both regression and classification inference
        // configs are exercised across test runs.
        DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
            .setId(JOB_ID)
            .setDescription(JOB_DESCRIPTION)
            .setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
            .setDest(new DataFrameAnalyticsDest("my_dest", null))
            .setAnalysis(randomBoolean() ? new Regression("foo") : new Classification("foo"))
            .build();
        List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));

        // Make both storage calls succeed by completing their listeners immediately.
        doAnswer(invocationOnMock -> {
            ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
            storeListener.onResponse(true);
            return null;
        }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class));

        doAnswer(invocationOnMock -> {
            ActionListener<Void> storeListener = (ActionListener<Void>) invocationOnMock.getArguments()[1];
            storeListener.onResponse(null);
            return null;
        }).when(trainedModelProvider).storeTrainedModelDefinitionDoc(any(TrainedModelDefinitionDoc.class), any(ActionListener.class));

        // Renamed local (was "resultProcessor", another leftover from the
        // AnalyticsResultProcessorTests this file was derived from).
        ChunkedTrainedModelPersister persister = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig);
        ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
        TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 0, false);
        TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true);

        // The size info must be handled first: it generates the model id and stores
        // the initial config; the chunks then follow in order.
        persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
        persister.createAndIndexInferenceModelDoc(chunk1);
        persister.createAndIndexInferenceModelDoc(chunk2);

        ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
        verify(trainedModelProvider).storeTrainedModelMetadata(storedModelCaptor.capture(), any(ActionListener.class));

        ArgumentCaptor<TrainedModelDefinitionDoc> storedDocCapture = ArgumentCaptor.forClass(TrainedModelDefinitionDoc.class);
        verify(trainedModelProvider, times(2))
            .storeTrainedModelDefinitionDoc(storedDocCapture.capture(), any(ActionListener.class));

        // The stored config must reflect the analytics job it came from and carry the
        // size estimates from the model_size_info document, but no inline definition.
        TrainedModelConfig storedModel = storedModelCaptor.getValue();
        assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
        assertThat(storedModel.getModelId(), containsString(JOB_ID));
        assertThat(storedModel.getVersion(), equalTo(Version.CURRENT));
        assertThat(storedModel.getCreatedBy(), equalTo(XPackUser.NAME));
        assertThat(storedModel.getTags(), contains(JOB_ID));
        assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
        assertThat(storedModel.getModelDefinition(), is(nullValue()));
        assertThat(storedModel.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
        assertThat(storedModel.getEstimatedOperations(), equalTo((long) modelSizeInfo.numOperations()));
        if (analyticsConfig.getAnalysis() instanceof Classification) {
            assertThat(storedModel.getInferenceConfig().getName(), equalTo("classification"));
        } else {
            assertThat(storedModel.getInferenceConfig().getName(), equalTo("regression"));
        }

        // Metadata embeds the originating analytics config under "analytics_config".
        Map<String, Object> metadata = storedModel.getMetadata();
        assertThat(metadata.size(), equalTo(1));
        assertThat(metadata, hasKey("analytics_config"));
        Map<String, Object> analyticsConfigAsMap = XContentHelper.convertToMap(JsonXContent.jsonXContent, analyticsConfig.toString(),
            true);
        assertThat(analyticsConfigAsMap, equalTo(metadata.get("analytics_config")));

        // Chunks are stored in order with increasing doc numbers...
        TrainedModelDefinitionDoc storedDoc1 = storedDocCapture.getAllValues().get(0);
        assertThat(storedDoc1.getDocNum(), equalTo(0));
        TrainedModelDefinitionDoc storedDoc2 = storedDocCapture.getAllValues().get(1);
        assertThat(storedDoc2.getDocNum(), equalTo(1));

        // ...and every definition doc shares the model id generated for the config.
        assertThat(storedModel.getModelId(), equalTo(storedDoc1.getModelId()));
        assertThat(storedModel.getModelId(), equalTo(storedDoc2.getModelId()));

        // A single success audit message is expected, and nothing else.
        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
        verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
        assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));
        Mockito.verifyNoMoreInteractions(auditor);
    }

    // Builds the persister under test with a no-op model-id consumer.
    private ChunkedTrainedModelPersister createChunkedTrainedModelPersister(List<ExtractedField> fieldNames,
                                                                            DataFrameAnalyticsConfig analyticsConfig) {
        return new ChunkedTrainedModelPersister(trainedModelProvider,
            analyticsConfig,
            auditor,
            (unused) -> {},
            new ExtractedFields(fieldNames, Collections.emptyMap()));
    }

}
|
|
@ -20,8 +20,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierD
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
|
|
||||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||||
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
|
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
|
||||||
|
@ -46,21 +44,18 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
|
||||||
protected AnalyticsResult createTestInstance() {
|
protected AnalyticsResult createTestInstance() {
|
||||||
RowResults rowResults = null;
|
RowResults rowResults = null;
|
||||||
PhaseProgress phaseProgress = null;
|
PhaseProgress phaseProgress = null;
|
||||||
TrainedModelDefinition.Builder inferenceModel = null;
|
|
||||||
MemoryUsage memoryUsage = null;
|
MemoryUsage memoryUsage = null;
|
||||||
OutlierDetectionStats outlierDetectionStats = null;
|
OutlierDetectionStats outlierDetectionStats = null;
|
||||||
ClassificationStats classificationStats = null;
|
ClassificationStats classificationStats = null;
|
||||||
RegressionStats regressionStats = null;
|
RegressionStats regressionStats = null;
|
||||||
ModelSizeInfo modelSizeInfo = null;
|
ModelSizeInfo modelSizeInfo = null;
|
||||||
|
TrainedModelDefinitionChunk trainedModelDefinitionChunk = null;
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
rowResults = RowResultsTests.createRandom();
|
rowResults = RowResultsTests.createRandom();
|
||||||
}
|
}
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100));
|
phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100));
|
||||||
}
|
}
|
||||||
if (randomBoolean()) {
|
|
||||||
inferenceModel = TrainedModelDefinitionTests.createRandomBuilder();
|
|
||||||
}
|
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
memoryUsage = MemoryUsageTests.createRandom();
|
memoryUsage = MemoryUsageTests.createRandom();
|
||||||
}
|
}
|
||||||
|
@ -76,8 +71,12 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
|
||||||
if (randomBoolean()) {
|
if (randomBoolean()) {
|
||||||
modelSizeInfo = ModelSizeInfoTests.createRandom();
|
modelSizeInfo = ModelSizeInfoTests.createRandom();
|
||||||
}
|
}
|
||||||
return new AnalyticsResult(rowResults, phaseProgress, inferenceModel, memoryUsage, outlierDetectionStats,
|
if (randomBoolean()) {
|
||||||
classificationStats, regressionStats, modelSizeInfo);
|
String def = randomAlphaOfLengthBetween(100, 1000);
|
||||||
|
trainedModelDefinitionChunk = new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean());
|
||||||
|
}
|
||||||
|
return new AnalyticsResult(rowResults, phaseProgress, memoryUsage, outlierDetectionStats,
|
||||||
|
classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue