[7.x][ML] Handle data frame analytics state spreading over multiple docs (#62564) (#62824)

When state persistence was first implemented for data frame analytics,
we assumed that state would always fit in a single document.
However, this is no longer the case.

This commit adds handling of state that spreads over multiple documents.

Backport of #62564
Dimitris Athanasiou 2020-09-23 16:16:34 +03:00 committed by GitHub
parent e3d5915566
commit 7de5201291
19 changed files with 117 additions and 73 deletions
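
The naming scheme behind the change: each analysis contributes an id infix, and state documents are named <jobId><infix><n> with one-based, gap-free ordinals, so readers can walk the sequence until an id is missing. A minimal sketch of the naming round trip, assuming the regression infix from the diffs below (the demo class itself is illustrative, not part of the commit):

// Illustrative only: mirrors the id scheme this commit introduces for
// regression ("_regression_state#") and classification ("_classification_state#").
public final class StateDocIdsDemo {

    private static final String INFIX = "_regression_state#";

    // Prefix shared by every state document of a job (getStateDocIdPrefix).
    static String prefix(String jobId) {
        return jobId + INFIX;
    }

    // Recover the job id from any state document id (extractJobIdFromStateDoc).
    static String extractJobId(String stateDocId) {
        int infixIndex = stateDocId.lastIndexOf(INFIX);
        return infixIndex <= 0 ? null : stateDocId.substring(0, infixIndex);
    }

    public static void main(String[] args) {
        String docId = prefix("my-job") + 3;      // my-job_regression_state#3
        System.out.println(docId);
        System.out.println(extractJobId(docId));  // my-job
    }
}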

View File

@@ -55,7 +55,7 @@ public class Classification implements DataFrameAnalysis {
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
-    private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
+    private static final String STATE_DOC_ID_INFIX = "_classification_state#";
     private static final String NUM_CLASSES = "num_classes";
@@ -413,8 +413,8 @@ public class Classification implements DataFrameAnalysis {
     }
     @Override
-    public String getStateDocId(String jobId) {
-        return jobId + STATE_DOC_ID_SUFFIX;
+    public String getStateDocIdPrefix(String jobId) {
+        return jobId + STATE_DOC_ID_INFIX;
     }
     @Override
@@ -439,7 +439,7 @@ public class Classification implements DataFrameAnalysis {
     }
     public static String extractJobIdFromStateDoc(String stateDocId) {
-        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
+        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_INFIX);
         return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
     }

View File

@@ -63,9 +63,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
     boolean persistsState();
     /**
-     * Returns the document id for the analysis state
+     * Returns the document id prefix for the analysis state
      */
-    String getStateDocId(String jobId);
+    String getStateDocIdPrefix(String jobId);
     /**
      * Returns the progress phases the analysis goes through in order

View File

@@ -264,7 +264,7 @@ public class OutlierDetection implements DataFrameAnalysis {
     }
     @Override
-    public String getStateDocId(String jobId) {
+    public String getStateDocIdPrefix(String jobId) {
         throw new UnsupportedOperationException("Outlier detection does not support state");
     }

View File

@@ -51,7 +51,7 @@ public class Regression implements DataFrameAnalysis {
     public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
-    private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
+    private static final String STATE_DOC_ID_INFIX = "_regression_state#";
     private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
     private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
@@ -319,8 +319,8 @@ public class Regression implements DataFrameAnalysis {
     }
     @Override
-    public String getStateDocId(String jobId) {
-        return jobId + STATE_DOC_ID_SUFFIX;
+    public String getStateDocIdPrefix(String jobId) {
+        return jobId + STATE_DOC_ID_INFIX;
     }
     @Override
@@ -342,7 +342,7 @@ public class Regression implements DataFrameAnalysis {
     }
     public static String extractJobIdFromStateDoc(String stateDocId) {
-        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
+        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_INFIX);
         return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
     }

View File

@@ -448,7 +448,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classification> {
         Classification classification = createRandom();
         assertThat(classification.persistsState(), is(true));
         String randomId = randomAlphaOfLength(10);
-        assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
+        assertThat(classification.getStateDocIdPrefix(randomId), equalTo(randomId + "_classification_state#"));
     }
     public void testExtractJobIdFromStateDoc() {

View File

@@ -120,7 +120,7 @@ public class OutlierDetectionTests extends AbstractBWCSerializationTestCase<OutlierDetection> {
     public void testGetStateDocId() {
         OutlierDetection outlierDetection = createRandom();
         assertThat(outlierDetection.persistsState(), is(false));
-        expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocId("foo"));
+        expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocIdPrefix("foo"));
     }
     public void testInferenceConfig() {

View File

@@ -331,7 +331,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression> {
         Regression regression = createRandom();
         assertThat(regression.persistsState(), is(true));
         String randomId = randomAlphaOfLength(10);
-        assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
+        assertThat(regression.getStateDocIdPrefix(randomId), equalTo(randomId + "_regression_state#"));
     }
     public void testExtractJobIdFromStateDoc() {

View File

@@ -649,8 +649,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                 new BlackHoleAutodetectProcess(job.getId(), onProcessCrash);
             // factor of 1.0 makes renormalization a no-op
             normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0);
-            analyticsProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
-            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
+            analyticsProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
+            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
         }
         NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory,
             threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME));

View File

@@ -28,6 +28,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest;
@@ -57,8 +58,6 @@ import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Objects;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -248,17 +247,46 @@ public class TransportDeleteDataFrameAnalyticsAction
                                        DataFrameAnalyticsConfig config,
                                        TimeValue timeout,
                                        ActionListener<BulkByScrollResponse> listener) {
-        List<String> ids = new ArrayList<>();
-        ids.add(StoredProgress.documentId(config.getId()));
-        if (config.getAnalysis().persistsState()) {
-            ids.add(config.getAnalysis().getStateDocId(config.getId()));
-        }
+        ActionListener<Boolean> deleteModelStateListener = ActionListener.wrap(
+            r -> executeDeleteByQuery(
+                parentTaskClient,
+                AnomalyDetectorsIndex.jobStateIndexPattern(),
+                QueryBuilders.idsQuery().addIds(StoredProgress.documentId(config.getId())),
+                timeout,
+                listener
+            )
+            , listener::onFailure
+        );
+        deleteModelState(parentTaskClient, config, timeout, 1, deleteModelStateListener);
+    }
+
+    private void deleteModelState(ParentTaskAssigningClient parentTaskClient,
+                                  DataFrameAnalyticsConfig config,
+                                  TimeValue timeout,
+                                  int docNum,
+                                  ActionListener<Boolean> listener) {
+        if (config.getAnalysis().persistsState() == false) {
+            listener.onResponse(true);
+            return;
+        }
+        IdsQueryBuilder query = QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocIdPrefix(config.getId()) + docNum);
         executeDeleteByQuery(
             parentTaskClient,
             AnomalyDetectorsIndex.jobStateIndexPattern(),
-            QueryBuilders.idsQuery().addIds(ids.toArray(new String[0])),
+            query,
             timeout,
-            listener
+            ActionListener.wrap(
+                response -> {
+                    if (response.getDeleted() > 0) {
+                        deleteModelState(parentTaskClient, config, timeout, docNum + 1, listener);
+                        return;
+                    }
+                    listener.onResponse(true);
+                },
+                listener::onFailure
+            )
         );
     }

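The hunk above replaces the single delete-by-query with a chain: the progress document is only deleted once every state document is gone, and state ordinals are walked until a delete matches nothing. A simplified, self-contained sketch of that chained-listener walk, assuming the standard delete-by-query client API (class and method names here are ours, not the commit's):

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest;

// Illustrative sketch of the chained delete above: remove "<prefix>1",
// "<prefix>2", ... until a delete-by-query matches nothing.
final class StateDocDeleter {

    static void deleteStateDocs(Client client, String index, String prefix, int docNum,
                                ActionListener<Boolean> listener) {
        DeleteByQueryRequest request = new DeleteByQueryRequest(index)
            .setQuery(QueryBuilders.idsQuery().addIds(prefix + docNum));
        client.execute(DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(
            response -> {
                if (response.getDeleted() > 0) {
                    // The ordinal existed; the next one may too.
                    deleteStateDocs(client, index, prefix, docNum + 1, listener);
                } else {
                    // First missing ordinal: nothing left to delete.
                    listener.onResponse(true);
                }
            },
            listener::onFailure));
    }

    private StateDocDeleter() {}
}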
View File

@@ -5,7 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.xpack.ml.process.NativeProcess;
 import java.io.IOException;
@@ -41,7 +41,8 @@ public interface AnalyticsProcess<ProcessResult> extends NativeProcess {
     /**
      * Restores the model state from a previously persisted one
-     * @param state the state to restore
+     * @param client the client to use for fetching the state documents
+     * @param stateDocIdPrefix the prefix of ids of the state documents
      */
-    void restoreState(BytesReference state) throws IOException;
+    void restoreState(Client client, String stateDocIdPrefix) throws IOException;
 }

View File

@@ -5,8 +5,6 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import java.util.concurrent.ExecutorService;
@@ -19,12 +17,12 @@ public interface AnalyticsProcessFactory<ProcessResult> {
      *
      * @param config The data frame analytics config
      * @param analyticsProcessConfig The process configuration
-     * @param state The state document to restore from if there is one available
+     * @param hasState Whether there is state to restore from
      * @param executorService Executor service used to start the async tasks a job needs to operate the analytical process
      * @param onProcessCrash Callback to execute if the process stops unexpectedly
      * @return The process
      */
     AnalyticsProcess<ProcessResult> createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                           @Nullable BytesReference state, ExecutorService executorService,
+                                                           boolean hasState, ExecutorService executorService,
                                                            Consumer<String> onProcessCrash);
 }

View File

@@ -15,13 +15,10 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.ParentTaskAssigningClient;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
@@ -139,11 +136,11 @@ public class AnalyticsProcessManager {
             }
             // Fetch existing model state (if any)
-            BytesReference state = getModelState(config);
+            final boolean hasState = hasModelState(config);
             boolean isProcessStarted;
             try {
-                isProcessStarted = processContext.startProcess(dataExtractorFactory, task, state);
+                isProcessStarted = processContext.startProcess(dataExtractorFactory, task, hasState);
             } catch (Exception e) {
                 processContext.stop();
                 task.setFailed(processContext.getFailureReason() == null ?
@@ -153,7 +150,7 @@
             if (isProcessStarted) {
                 executorServiceForProcess.execute(() -> processContext.resultProcessor.get().process(processContext.process.get()));
-                executorServiceForProcess.execute(() -> processData(task, processContext, state));
+                executorServiceForProcess.execute(() -> processData(task, processContext, hasState));
             } else {
                 processContextByAllocation.remove(task.getAllocationId());
                 auditor.info(config.getId(), Messages.DATA_FRAME_ANALYTICS_AUDIT_FINISHED_ANALYSIS);
@@ -162,23 +159,22 @@
         });
     }
-    @Nullable
-    private BytesReference getModelState(DataFrameAnalyticsConfig config) {
+    private boolean hasModelState(DataFrameAnalyticsConfig config) {
         if (config.getAnalysis().persistsState() == false) {
-            return null;
+            return false;
         }
         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
             SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
                 .setSize(1)
-                .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
+                .setFetchSource(false)
+                .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocIdPrefix(config.getId()) + "1"))
                 .get();
-            SearchHit[] hits = searchResponse.getHits().getHits();
-            return hits.length == 0 ? null : hits[0].getSourceRef();
+            return searchResponse.getHits().getHits().length == 1;
         }
     }
-    private void processData(DataFrameAnalyticsTask task, ProcessContext processContext, BytesReference state) {
+    private void processData(DataFrameAnalyticsTask task, ProcessContext processContext, boolean hasState) {
         LOGGER.info("[{}] Started loading data", processContext.config.getId());
         auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_LOADING_DATA));
@@ -193,7 +189,7 @@
             process.writeEndOfDataMessage();
             process.flushStream();
-            restoreState(task, config, state, process);
+            restoreState(task, config, process, hasState);
             LOGGER.info("[{}] Started analyzing", processContext.config.getId());
             auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_ANALYZING));
@@ -297,14 +293,14 @@
             process.writeRecord(headerRecord);
     }
-    private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state,
-                              AnalyticsProcess<AnalyticsResult> process) {
+    private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, AnalyticsProcess<AnalyticsResult> process,
+                              boolean hasState) {
         if (config.getAnalysis().persistsState() == false) {
             LOGGER.debug("[{}] Analysis does not support state", config.getId());
             return;
         }
-        if (state == null) {
+        if (hasState == false) {
             LOGGER.debug("[{}] No model state available to restore", config.getId());
             return;
         }
@@ -313,7 +309,7 @@
         auditor.info(config.getId(), Messages.DATA_FRAME_ANALYTICS_AUDIT_RESTORING_STATE);
         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            process.restoreState(state);
+            process.restoreState(client, config.getAnalysis().getStateDocIdPrefix(config.getId()));
         } catch (Exception e) {
             LOGGER.error(new ParameterizedMessage("[{}] Failed to restore state", process.getConfig().jobId()), e);
             task.setFailed(ExceptionsHelper.serverError("Failed to restore state: " + e.getMessage()));
@@ -321,9 +317,9 @@
     }
     private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
-                                                            AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
-        AnalyticsProcess<AnalyticsResult> process =
-            processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
+                                                            AnalyticsProcessConfig analyticsProcessConfig, boolean hasState) {
+        AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(
+            config, analyticsProcessConfig, hasState, executorServiceForProcess, onProcessCrash(task));
         if (process.isProcessAlive() == false) {
             throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
         }
@@ -467,7 +463,7 @@
          */
         synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory,
                                           DataFrameAnalyticsTask task,
-                                          @Nullable BytesReference state) {
+                                          boolean hasState) {
             if (task.isStopping()) {
                 // The job was stopped before we started the process so no need to start it
                 return false;
@@ -483,7 +479,7 @@
                 LOGGER.info("[{}] no data found to analyze. Will not start analytics native process.", config.getId());
                 return false;
             }
-            process.set(createProcess(task, config, analyticsProcessConfig, state));
+            process.set(createProcess(task, config, analyticsProcessConfig, hasState));
             resultProcessor.set(createResultProcessor(task, dataExtractorFactory));
             return true;
         }

View File

@@ -85,7 +85,7 @@ public class MemoryUsageEstimationProcessManager {
         processFactory.createAnalyticsProcess(
             config,
             processConfig,
-            null,
+            false,
             executorServiceForProcess,
             // The handler passed here will never be called as AbstractNativeProcess.detectCrash method returns early when
             // (processInStream == null) which is the case for MemoryUsageEstimationProcess.

View File

@@ -5,8 +5,15 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
-import org.elasticsearch.common.bytes.BytesReference;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;
 import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper;
@@ -21,6 +28,8 @@ import java.util.function.Consumer;
 public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<AnalyticsResult> {
+    private static final Logger logger = LogManager.getLogger(NativeAnalyticsProcess.class);
     private static final String NAME = "analytics";
     private final AnalyticsProcessConfig config;
@@ -55,10 +64,26 @@
     }
     @Override
-    public void restoreState(BytesReference state) throws IOException {
-        Objects.requireNonNull(state);
+    public void restoreState(Client client, String stateDocIdPrefix) throws IOException {
+        Objects.requireNonNull(stateDocIdPrefix);
         try (OutputStream restoreStream = processRestoreStream()) {
-            StateToProcessWriterHelper.writeStateToStream(state, restoreStream);
+            int docNum = 0;
+            while (true) {
+                if (isProcessKilled()) {
+                    return;
+                }
+                // We fetch the documents one at a time because all together they can amount to too much memory
+                SearchResponse stateResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
+                    .setSize(1)
+                    .setQuery(QueryBuilders.idsQuery().addIds(stateDocIdPrefix + ++docNum)).get();
+                if (stateResponse.getHits().getHits().length == 0) {
+                    break;
+                }
+                SearchHit stateDoc = stateResponse.getHits().getAt(0);
+                logger.debug(() -> new ParameterizedMessage("[{}] Restoring state document [{}]", config.jobId(), stateDoc.getId()));
+                StateToProcessWriterHelper.writeStateToStream(stateDoc.getSourceRef(), restoreStream);
+            }
         }
     }
 }

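Both the restore loop above and the delete chain in TransportDeleteDataFrameAnalyticsAction depend on the same implicit protocol: state documents carry consecutive one-based ordinals with no gaps, so the first missing ordinal terminates the walk. A self-contained sketch of that protocol with a plain map standing in for the state index (all names illustrative):

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative protocol check: consecutive one-based ordinals, walk
// until the first missing id. A HashMap stands in for the state index.
public final class StateSequenceDemo {

    public static void main(String[] args) {
        Map<String, byte[]> index = new HashMap<>();
        String prefix = "my-job_regression_state#";

        // Persist state split over three documents.
        index.put(prefix + 1, "chunk-1".getBytes(StandardCharsets.UTF_8));
        index.put(prefix + 2, "chunk-2".getBytes(StandardCharsets.UTF_8));
        index.put(prefix + 3, "chunk-3".getBytes(StandardCharsets.UTF_8));

        // Restore: fetch one document per ordinal until one is missing.
        List<byte[]> restored = new ArrayList<>();
        int docNum = 0;
        while (true) {
            byte[] chunk = index.get(prefix + ++docNum);
            if (chunk == null) {
                break; // first missing ordinal ends the sequence
            }
            restored.add(chunk);
        }
        System.out.println("restored " + restored.size() + " chunks"); // restored 3 chunks
    }
}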
View File

@@ -8,8 +8,6 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -70,12 +68,12 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<AnalyticsResult> {
     @Override
     public NativeAnalyticsProcess createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                         @Nullable BytesReference state, ExecutorService executorService,
+                                                         boolean hasState, ExecutorService executorService,
                                                          Consumer<String> onProcessCrash) {
         String jobId = config.getId();
         List<Path> filesToDelete = new ArrayList<>();
         ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId,
-            false, true, true, state != null, config.getAnalysis().persistsState());
+            false, true, true, hasState, config.getAnalysis().persistsState());
         // The extra 2 are for the checksum and the control field
         int numberOfFields = analyticsProcessConfig.cols() + 2;

View File

@@ -5,7 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;
@@ -32,7 +32,7 @@ public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsProcess<MemoryUsageEstimationResult> {
     }
     @Override
-    public void restoreState(BytesReference state) {
+    public void restoreState(Client client, String stateDocIdPrefix) {
         throw new UnsupportedOperationException();
     }
 }

View File

@@ -8,8 +8,6 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -60,7 +58,7 @@ public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProcessFactory<MemoryUsageEstimationResult> {
     public NativeMemoryUsageEstimationProcess createAnalyticsProcess(
             DataFrameAnalyticsConfig config,
             AnalyticsProcessConfig analyticsProcessConfig,
-            @Nullable BytesReference state,
+            boolean hasState,
             ExecutorService executorService,
             Consumer<String> onProcessCrash) {
         List<Path> filesToDelete = new ArrayList<>();

View File

@@ -92,7 +92,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         when(process.isProcessAlive()).thenReturn(true);
         when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT).iterator());
         processFactory = mock(AnalyticsProcessFactory.class);
-        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
+        when(processFactory.createAnalyticsProcess(any(), any(), anyBoolean(), any(), any())).thenReturn(process);
         auditor = mock(DataFrameAnalyticsAuditor.class);
         trainedModelProvider = mock(TrainedModelProvider.class);
@@ -226,7 +226,7 @@
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);
         processContext.stop();
-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(false));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(false));
         InOrder inOrder = inOrder(dataExtractor, process, task);
         inOrder.verify(task).isStopping();
@@ -237,7 +237,7 @@
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);
-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(false));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(false));
         InOrder inOrder = inOrder(dataExtractor, process, task);
         inOrder.verify(task).isStopping();
@@ -248,7 +248,7 @@
     public void testProcessContext_StartAndStop() throws Exception {
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);
-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(true));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(true));
         processContext.stop();
         InOrder inOrder = inOrder(dataExtractor, process, task);

View File

@@ -66,7 +66,7 @@ public class MemoryUsageEstimationProcessManagerTests extends ESTestCase {
         process = mock(AnalyticsProcess.class);
         when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT).iterator());
         processFactory = mock(AnalyticsProcessFactory.class);
-        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
+        when(processFactory.createAnalyticsProcess(any(), any(), anyBoolean(), any(), any())).thenReturn(process);
         dataExtractor = mock(DataFrameDataExtractor.class);
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
         dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);