[7.x][ML] Handle data frame analytics state spreading over multiple docs (#62564) (#62824)

When state persistence was first implemented for data frame analytics,
we assumed that state would always fit in a single document.
However, this is no longer the case.

This commit adds handling of state that spreads over multiple documents.

Backport of #62564
Dimitris Athanasiou 2020-09-23 16:16:34 +03:00, committed by GitHub
commit 7de5201291 (parent e3d5915566)
19 changed files with 117 additions and 73 deletions
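The scheme at the heart of the change: instead of a single state document whose id ends in "#1", each analysis now exposes a per-job id prefix, and state documents are numbered consecutively starting at 1. A minimal standalone sketch of the id scheme (the demo class is illustrative and not part of the commit; it borrows the regression infix from the diff below):

    public class StateDocIdSchemeDemo {

        // Mirrors the STATE_DOC_ID_INFIX constant introduced in Regression below.
        private static final String STATE_DOC_ID_INFIX = "_regression_state#";

        static String stateDocIdPrefix(String jobId) {
            return jobId + STATE_DOC_ID_INFIX;
        }

        public static void main(String[] args) {
            String prefix = stateDocIdPrefix("my_job");
            // Prints my_job_regression_state#1, #2, #3: consumers iterate the
            // document numbers until a lookup misses.
            for (int docNum = 1; docNum <= 3; docNum++) {
                System.out.println(prefix + docNum);
            }
        }
    }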

Classification.java

@@ -55,7 +55,7 @@ public class Classification implements DataFrameAnalysis {
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
-    private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
+    private static final String STATE_DOC_ID_INFIX = "_classification_state#";
     private static final String NUM_CLASSES = "num_classes";
@@ -413,8 +413,8 @@ public class Classification implements DataFrameAnalysis {
     }

     @Override
-    public String getStateDocId(String jobId) {
-        return jobId + STATE_DOC_ID_SUFFIX;
+    public String getStateDocIdPrefix(String jobId) {
+        return jobId + STATE_DOC_ID_INFIX;
     }

     @Override
@@ -439,7 +439,7 @@ public class Classification implements DataFrameAnalysis {
     }

     public static String extractJobIdFromStateDoc(String stateDocId) {
-        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
+        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_INFIX);
         return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
     }
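For illustration, job-id extraction keeps working whatever the document number, because the lookup is on the infix rather than on the old fixed suffix. A hypothetical round-trip, not part of the commit:

    public class ExtractJobIdDemo {
        public static void main(String[] args) {
            // Same logic as Classification.extractJobIdFromStateDoc above.
            String stateDocId = "my_job_classification_state#7";
            int infixIndex = stateDocId.lastIndexOf("_classification_state#");
            String jobId = infixIndex <= 0 ? null : stateDocId.substring(0, infixIndex);
            System.out.println(jobId);  // prints "my_job", for doc #7 just as for doc #1
        }
    }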

DataFrameAnalysis.java

@@ -63,9 +63,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
     boolean persistsState();

     /**
-     * Returns the document id for the analysis state
+     * Returns the document id prefix for the analysis state
      */
-    String getStateDocId(String jobId);
+    String getStateDocIdPrefix(String jobId);

     /**
      * Returns the progress phases the analysis goes through in order

OutlierDetection.java

@@ -264,7 +264,7 @@ public class OutlierDetection implements DataFrameAnalysis {
     }

     @Override
-    public String getStateDocId(String jobId) {
+    public String getStateDocIdPrefix(String jobId) {
         throw new UnsupportedOperationException("Outlier detection does not support state");
     }

Regression.java

@@ -51,7 +51,7 @@ public class Regression implements DataFrameAnalysis {
     public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
-    private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
+    private static final String STATE_DOC_ID_INFIX = "_regression_state#";
     private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
     private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
@@ -319,8 +319,8 @@ public class Regression implements DataFrameAnalysis {
     }

     @Override
-    public String getStateDocId(String jobId) {
-        return jobId + STATE_DOC_ID_SUFFIX;
+    public String getStateDocIdPrefix(String jobId) {
+        return jobId + STATE_DOC_ID_INFIX;
     }

     @Override
@@ -342,7 +342,7 @@ public class Regression implements DataFrameAnalysis {
     }

     public static String extractJobIdFromStateDoc(String stateDocId) {
-        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
+        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_INFIX);
         return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
     }

ClassificationTests.java

@@ -448,7 +448,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classification> {
         Classification classification = createRandom();
         assertThat(classification.persistsState(), is(true));
         String randomId = randomAlphaOfLength(10);
-        assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
+        assertThat(classification.getStateDocIdPrefix(randomId), equalTo(randomId + "_classification_state#"));
     }

     public void testExtractJobIdFromStateDoc() {

OutlierDetectionTests.java

@@ -120,7 +120,7 @@ public class OutlierDetectionTests extends AbstractBWCSerializationTestCase<OutlierDetection> {
     public void testGetStateDocId() {
         OutlierDetection outlierDetection = createRandom();
         assertThat(outlierDetection.persistsState(), is(false));
-        expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocId("foo"));
+        expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocIdPrefix("foo"));
     }

     public void testInferenceConfig() {

RegressionTests.java

@@ -331,7 +331,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression> {
         Regression regression = createRandom();
         assertThat(regression.persistsState(), is(true));
         String randomId = randomAlphaOfLength(10);
-        assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
+        assertThat(regression.getStateDocIdPrefix(randomId), equalTo(randomId + "_regression_state#"));
     }

     public void testExtractJobIdFromStateDoc() {

MachineLearning.java

@@ -649,8 +649,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                 new BlackHoleAutodetectProcess(job.getId(), onProcessCrash);
             // factor of 1.0 makes renormalization a no-op
             normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0);
-            analyticsProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
-            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
+            analyticsProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
+            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
         }
         NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory,
             threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME));

TransportDeleteDataFrameAnalyticsAction.java

@@ -28,6 +28,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest;
@@ -57,8 +58,6 @@ import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;

 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Objects;

 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -248,17 +247,46 @@ public class TransportDeleteDataFrameAnalyticsAction
                                       DataFrameAnalyticsConfig config,
                                       TimeValue timeout,
                                       ActionListener<BulkByScrollResponse> listener) {
-        List<String> ids = new ArrayList<>();
-        ids.add(StoredProgress.documentId(config.getId()));
-        if (config.getAnalysis().persistsState()) {
-            ids.add(config.getAnalysis().getStateDocId(config.getId()));
+        ActionListener<Boolean> deleteModelStateListener = ActionListener.wrap(
+            r -> executeDeleteByQuery(
+                    parentTaskClient,
+                    AnomalyDetectorsIndex.jobStateIndexPattern(),
+                    QueryBuilders.idsQuery().addIds(StoredProgress.documentId(config.getId())),
+                    timeout,
+                    listener
+                )
+            , listener::onFailure
+        );
+
+        deleteModelState(parentTaskClient, config, timeout, 1, deleteModelStateListener);
+    }
+
+    private void deleteModelState(ParentTaskAssigningClient parentTaskClient,
+                                  DataFrameAnalyticsConfig config,
+                                  TimeValue timeout,
+                                  int docNum,
+                                  ActionListener<Boolean> listener) {
+        if (config.getAnalysis().persistsState() == false) {
+            listener.onResponse(true);
+            return;
         }
+
+        IdsQueryBuilder query = QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocIdPrefix(config.getId()) + docNum);
         executeDeleteByQuery(
             parentTaskClient,
             AnomalyDetectorsIndex.jobStateIndexPattern(),
-            QueryBuilders.idsQuery().addIds(ids.toArray(new String[0])),
+            query,
             timeout,
-            listener
+            ActionListener.wrap(
+                response -> {
+                    if (response.getDeleted() > 0) {
+                        deleteModelState(parentTaskClient, config, timeout, docNum + 1, listener);
+                        return;
+                    }
+                    listener.onResponse(true);
+                },
+                listener::onFailure
+            )
         );
     }
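Deletion now walks the numbered state documents: the action deletes document number 1 and, as long as each delete-by-query reports at least one deletion, recurses with the next number; the first miss means all state documents are gone, and only then is the progress document deleted. A simplified synchronous sketch of the same loop, where deleteDocById is a hypothetical stand-in for executeDeleteByQuery:

    // Sketch only: deleteDocById stands in for a delete-by-query on the state
    // index and returns the number of deleted documents.
    static void deleteAllStateDocs(String stateDocIdPrefix) {
        int docNum = 1;
        // e.g. "my_job_regression_state#1", "#2", ... until a number is missing
        while (deleteDocById(stateDocIdPrefix + docNum) > 0) {
            docNum++;
        }
    }

Stopping at the first miss is safe because state documents are written with consecutive numbers starting at 1.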

AnalyticsProcess.java

@@ -5,7 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;

-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.xpack.ml.process.NativeProcess;

 import java.io.IOException;
@@ -41,7 +41,8 @@ public interface AnalyticsProcess<ProcessResult> extends NativeProcess {

     /**
      * Restores the model state from a previously persisted one
-     * @param state the state to restore
+     * @param client the client to use for fetching the state documents
+     * @param stateDocIdPrefix the prefix of ids of the state documents
      */
-    void restoreState(BytesReference state) throws IOException;
+    void restoreState(Client client, String stateDocIdPrefix) throws IOException;
 }

AnalyticsProcessFactory.java

@@ -5,8 +5,6 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;

-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;

 import java.util.concurrent.ExecutorService;
@@ -19,12 +17,12 @@ public interface AnalyticsProcessFactory<ProcessResult> {
      *
      * @param config The data frame analytics config
      * @param analyticsProcessConfig The process configuration
-     * @param state The state document to restore from if there is one available
+     * @param hasState Whether there is state to restore from
      * @param executorService Executor service used to start the async tasks a job needs to operate the analytical process
      * @param onProcessCrash Callback to execute if the process stops unexpectedly
      * @return The process
      */
     AnalyticsProcess<ProcessResult> createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                           @Nullable BytesReference state, ExecutorService executorService,
+                                                           boolean hasState, ExecutorService executorService,
                                                            Consumer<String> onProcessCrash);
 }

AnalyticsProcessManager.java

@@ -15,13 +15,10 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.ParentTaskAssigningClient;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
@@ -139,11 +136,11 @@ public class AnalyticsProcessManager {
             }

             // Fetch existing model state (if any)
-            BytesReference state = getModelState(config);
+            final boolean hasState = hasModelState(config);

             boolean isProcessStarted;
             try {
-                isProcessStarted = processContext.startProcess(dataExtractorFactory, task, state);
+                isProcessStarted = processContext.startProcess(dataExtractorFactory, task, hasState);
             } catch (Exception e) {
                 processContext.stop();
                 task.setFailed(processContext.getFailureReason() == null ?
@@ -153,7 +150,7 @@ public class AnalyticsProcessManager {
             if (isProcessStarted) {
                 executorServiceForProcess.execute(() -> processContext.resultProcessor.get().process(processContext.process.get()));
-                executorServiceForProcess.execute(() -> processData(task, processContext, state));
+                executorServiceForProcess.execute(() -> processData(task, processContext, hasState));
             } else {
                 processContextByAllocation.remove(task.getAllocationId());
                 auditor.info(config.getId(), Messages.DATA_FRAME_ANALYTICS_AUDIT_FINISHED_ANALYSIS);
@@ -162,23 +159,22 @@ public class AnalyticsProcessManager {
         });
     }

-    @Nullable
-    private BytesReference getModelState(DataFrameAnalyticsConfig config) {
+    private boolean hasModelState(DataFrameAnalyticsConfig config) {
         if (config.getAnalysis().persistsState() == false) {
-            return null;
+            return false;
         }

         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
             SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
                 .setSize(1)
-                .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
+                .setFetchSource(false)
+                .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocIdPrefix(config.getId()) + "1"))
                 .get();
-            SearchHit[] hits = searchResponse.getHits().getHits();
-            return hits.length == 0 ? null : hits[0].getSourceRef();
+            return searchResponse.getHits().getHits().length == 1;
         }
     }

-    private void processData(DataFrameAnalyticsTask task, ProcessContext processContext, BytesReference state) {
+    private void processData(DataFrameAnalyticsTask task, ProcessContext processContext, boolean hasState) {
         LOGGER.info("[{}] Started loading data", processContext.config.getId());
         auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_LOADING_DATA));
@@ -193,7 +189,7 @@ public class AnalyticsProcessManager {
                 process.writeEndOfDataMessage();
                 process.flushStream();

-                restoreState(task, config, state, process);
+                restoreState(task, config, process, hasState);

                 LOGGER.info("[{}] Started analyzing", processContext.config.getId());
                 auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_ANALYZING));
@@ -297,14 +293,14 @@ public class AnalyticsProcessManager {
             process.writeRecord(headerRecord);
     }

-    private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state,
-                              AnalyticsProcess<AnalyticsResult> process) {
+    private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, AnalyticsProcess<AnalyticsResult> process,
+                              boolean hasState) {
         if (config.getAnalysis().persistsState() == false) {
             LOGGER.debug("[{}] Analysis does not support state", config.getId());
             return;
         }

-        if (state == null) {
+        if (hasState == false) {
             LOGGER.debug("[{}] No model state available to restore", config.getId());
             return;
         }
@@ -313,7 +309,7 @@ public class AnalyticsProcessManager {
         auditor.info(config.getId(), Messages.DATA_FRAME_ANALYTICS_AUDIT_RESTORING_STATE);

         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            process.restoreState(state);
+            process.restoreState(client, config.getAnalysis().getStateDocIdPrefix(config.getId()));
         } catch (Exception e) {
             LOGGER.error(new ParameterizedMessage("[{}] Failed to restore state", process.getConfig().jobId()), e);
             task.setFailed(ExceptionsHelper.serverError("Failed to restore state: " + e.getMessage()));
@@ -321,9 +317,9 @@ public class AnalyticsProcessManager {
     }

     private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
-                                                            AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
-        AnalyticsProcess<AnalyticsResult> process =
-            processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
+                                                            AnalyticsProcessConfig analyticsProcessConfig, boolean hasState) {
+        AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(
+            config, analyticsProcessConfig, hasState, executorServiceForProcess, onProcessCrash(task));
         if (process.isProcessAlive() == false) {
             throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
         }
@@ -467,7 +463,7 @@ public class AnalyticsProcessManager {
          */
         synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory,
                                           DataFrameAnalyticsTask task,
-                                          @Nullable BytesReference state) {
+                                          boolean hasState) {
             if (task.isStopping()) {
                 // The job was stopped before we started the process so no need to start it
                 return false;
@@ -483,7 +479,7 @@ public class AnalyticsProcessManager {
                 LOGGER.info("[{}] no data found to analyze. Will not start analytics native process.", config.getId());
                 return false;
             }
-            process.set(createProcess(task, config, analyticsProcessConfig, state));
+            process.set(createProcess(task, config, analyticsProcessConfig, hasState));
             resultProcessor.set(createResultProcessor(task, dataExtractorFactory));
             return true;
         }
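Note that hasModelState only needs to know whether the first state document exists, so it looks up the id <prefix>1 with setFetchSource(false) and the potentially very large state source is never transferred. The existence-check pattern in isolation (a fragment, assuming a Client named client and a stateDocIdPrefix are in scope):

    // Does at least (equivalently, exactly) one state document exist for this job?
    boolean hasState = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
        .setSize(1)
        .setFetchSource(false)  // only the hit count matters; never load the source
        .setQuery(QueryBuilders.idsQuery().addIds(stateDocIdPrefix + "1"))
        .get()
        .getHits().getHits().length == 1;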

MemoryUsageEstimationProcessManager.java

@@ -85,7 +85,7 @@ public class MemoryUsageEstimationProcessManager {
             processFactory.createAnalyticsProcess(
                 config,
                 processConfig,
-                null,
+                false,
                 executorServiceForProcess,
                 // The handler passed here will never be called as AbstractNativeProcess.detectCrash method returns early when
                 // (processInStream == null) which is the case for MemoryUsageEstimationProcess.

NativeAnalyticsProcess.java

@@ -5,8 +5,15 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;

-import org.elasticsearch.common.bytes.BytesReference;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;
 import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper;
@@ -21,6 +28,8 @@ import java.util.function.Consumer;

 public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<AnalyticsResult> {

+    private static final Logger logger = LogManager.getLogger(NativeAnalyticsProcess.class);
+
     private static final String NAME = "analytics";

     private final AnalyticsProcessConfig config;
@@ -55,10 +64,26 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<AnalyticsResult> {
     }

     @Override
-    public void restoreState(BytesReference state) throws IOException {
-        Objects.requireNonNull(state);
+    public void restoreState(Client client, String stateDocIdPrefix) throws IOException {
+        Objects.requireNonNull(stateDocIdPrefix);
         try (OutputStream restoreStream = processRestoreStream()) {
-            StateToProcessWriterHelper.writeStateToStream(state, restoreStream);
+            int docNum = 0;
+            while (true) {
+                if (isProcessKilled()) {
+                    return;
+                }
+
+                // We fetch the documents one at a time because all together they can amount to too much memory
+                SearchResponse stateResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
+                    .setSize(1)
+                    .setQuery(QueryBuilders.idsQuery().addIds(stateDocIdPrefix + ++docNum)).get();
+                if (stateResponse.getHits().getHits().length == 0) {
+                    break;
+                }
+                SearchHit stateDoc = stateResponse.getHits().getAt(0);
+                logger.debug(() -> new ParameterizedMessage("[{}] Restoring state document [{}]", config.jobId(), stateDoc.getId()));
+                StateToProcessWriterHelper.writeStateToStream(stateDoc.getSourceRef(), restoreStream);
+            }
         }
     }
 }
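Restore mirrors the write side: the loop pages through the numbered documents, issuing one search per document, and streams each source into the process restore pipe until a number is missing; fetching a single document per request keeps memory bounded even when the combined state is large. A simplified sketch of the paging, where fetchStateDoc and writeToRestoreStream are hypothetical stand-ins for the search and the StateToProcessWriterHelper call above:

    // Sketch only: fetchStateDoc returns the source of the given state doc id,
    // or null if the document does not exist; writeToRestoreStream feeds the
    // native process.
    static void restoreAllStateDocs(String stateDocIdPrefix) throws IOException {
        int docNum = 0;
        BytesReference stateDoc;
        while ((stateDoc = fetchStateDoc(stateDocIdPrefix + ++docNum)) != null) {
            writeToRestoreStream(stateDoc);  // first missing number ends the loop
        }
    }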

NativeAnalyticsProcessFactory.java

@@ -8,8 +8,6 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -70,12 +68,12 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<AnalyticsResult> {

     @Override
     public NativeAnalyticsProcess createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                         @Nullable BytesReference state, ExecutorService executorService,
+                                                         boolean hasState, ExecutorService executorService,
                                                          Consumer<String> onProcessCrash) {
         String jobId = config.getId();
         List<Path> filesToDelete = new ArrayList<>();
         ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId,
-            false, true, true, state != null, config.getAnalysis().persistsState());
+            false, true, true, hasState, config.getAnalysis().persistsState());

         // The extra 2 are for the checksum and the control field
         int numberOfFields = analyticsProcessConfig.cols() + 2;

NativeMemoryUsageEstimationProcess.java

@@ -5,7 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;

-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;
@@ -32,7 +32,7 @@ public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsProcess<MemoryUsageEstimationResult> {
     }

     @Override
-    public void restoreState(BytesReference state) {
+    public void restoreState(Client client, String stateDocIdPrefix) {
         throw new UnsupportedOperationException();
     }
 }

NativeMemoryUsageEstimationProcessFactory.java

@@ -8,8 +8,6 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -60,7 +58,7 @@ public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProcessFactory<MemoryUsageEstimationResult> {
     public NativeMemoryUsageEstimationProcess createAnalyticsProcess(
             DataFrameAnalyticsConfig config,
             AnalyticsProcessConfig analyticsProcessConfig,
-            @Nullable BytesReference state,
+            boolean hasState,
             ExecutorService executorService,
             Consumer<String> onProcessCrash) {
         List<Path> filesToDelete = new ArrayList<>();

AnalyticsProcessManagerTests.java

@@ -92,7 +92,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         when(process.isProcessAlive()).thenReturn(true);
         when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT).iterator());
         processFactory = mock(AnalyticsProcessFactory.class);
-        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
+        when(processFactory.createAnalyticsProcess(any(), any(), anyBoolean(), any(), any())).thenReturn(process);

         auditor = mock(DataFrameAnalyticsAuditor.class);
         trainedModelProvider = mock(TrainedModelProvider.class);
@@ -226,7 +226,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);
         processContext.stop();

-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(false));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(false));

         InOrder inOrder = inOrder(dataExtractor, process, task);
         inOrder.verify(task).isStopping();
@@ -237,7 +237,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);

-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(false));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(false));

         InOrder inOrder = inOrder(dataExtractor, process, task);
         inOrder.verify(task).isStopping();
@@ -248,7 +248,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
     public void testProcessContext_StartAndStop() throws Exception {
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);

-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(true));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(true));
         processContext.stop();

         InOrder inOrder = inOrder(dataExtractor, process, task);
MemoryUsageEstimationProcessManagerTests.java

@@ -66,7 +66,7 @@ public class MemoryUsageEstimationProcessManagerTests extends ESTestCase {
         process = mock(AnalyticsProcess.class);
         when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT).iterator());
         processFactory = mock(AnalyticsProcessFactory.class);
-        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
+        when(processFactory.createAnalyticsProcess(any(), any(), anyBoolean(), any(), any())).thenReturn(process);
         dataExtractor = mock(DataFrameDataExtractor.class);
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
         dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);