[7.x] Make AnalyticsProcessManager class more robust (#49282) (#49356)

This commit is contained in:
Przemysław Witek 2019-11-20 10:08:16 +01:00 committed by GitHub
parent f753fa2265
commit 9c0ec7ce23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 296 additions and 58 deletions

View File

@ -11,6 +11,7 @@ import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@ -180,7 +181,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Finished analysis"); "Finished analysis");
} }
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/47612")
public void testStopAndRestart() throws Exception { public void testStopAndRestart() throws Exception {
initialize("regression_stop_and_restart"); initialize("regression_stop_and_restart");
indexData(sourceIndex, 350, 0); indexData(sourceIndex, 350, 0);
@ -197,7 +197,11 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED. // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
assertBusy(() -> { assertBusy(() -> {
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState(); DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING), assertThat(
state,
is(anyOf(
equalTo(DataFrameAnalyticsState.REINDEXING),
equalTo(DataFrameAnalyticsState.ANALYZING),
equalTo(DataFrameAnalyticsState.STOPPED)))); equalTo(DataFrameAnalyticsState.STOPPED))));
}); });
stopAnalytics(jobId); stopAnalytics(jobId);
@ -214,7 +218,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
} }
} }
waitUntilAnalyticsIsStopped(jobId); waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) { for (SearchHit hit : sourceData.getHits()) {

View File

@ -10,7 +10,6 @@ import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
@ -54,7 +53,8 @@ public class AnalyticsProcessManager {
private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class); private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
private final Client client; private final Client client;
private final ThreadPool threadPool; private final ExecutorService executorServiceForJob;
private final ExecutorService executorServiceForProcess;
private final AnalyticsProcessFactory<AnalyticsResult> processFactory; private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>(); private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
private final DataFrameAnalyticsAuditor auditor; private final DataFrameAnalyticsAuditor auditor;
@ -65,8 +65,25 @@ public class AnalyticsProcessManager {
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory, AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) { TrainedModelProvider trainedModelProvider) {
this(
client,
threadPool.generic(),
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
analyticsProcessFactory,
auditor,
trainedModelProvider);
}
// Visible for testing
public AnalyticsProcessManager(Client client,
ExecutorService executorServiceForJob,
ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
this.client = Objects.requireNonNull(client); this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool); this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory); this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
this.auditor = Objects.requireNonNull(auditor); this.auditor = Objects.requireNonNull(auditor);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
@ -74,31 +91,33 @@ public class AnalyticsProcessManager {
public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory, public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
Consumer<Exception> finishHandler) { Consumer<Exception> finishHandler) {
threadPool.generic().execute(() -> { executorServiceForJob.execute(() -> {
ProcessContext processContext = new ProcessContext(config.getId());
synchronized (this) {
if (task.isStopping()) { if (task.isStopping()) {
// The task was requested to stop before we created the process context // The task was requested to stop before we created the process context
finishHandler.accept(null); finishHandler.accept(null);
return; return;
} }
// First we refresh the dest index to ensure data is searchable
refreshDest(config);
ProcessContext processContext = new ProcessContext(config.getId());
if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id finishHandler.accept(
+ "] Could not create process as one already exists")); ExceptionsHelper.serverError("[" + config.getId() + "] Could not create process as one already exists"));
return; return;
} }
}
// Refresh the dest index to ensure data is searchable
refreshDest(config);
// Fetch existing model state (if any)
BytesReference state = getModelState(config); BytesReference state = getModelState(config);
if (processContext.startProcess(dataExtractorFactory, config, task, state)) { if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); executorServiceForProcess.execute(() -> processResults(processContext));
executorService.execute(() -> processResults(processContext)); executorServiceForProcess.execute(() -> processData(task, config, processContext.dataExtractor,
executorService.execute(() -> processData(task, config, processContext.dataExtractor,
processContext.process, processContext.resultProcessor, finishHandler, state)); processContext.process, processContext.resultProcessor, finishHandler, state));
} else { } else {
processContextByAllocation.remove(task.getAllocationId());
finishHandler.accept(null); finishHandler.accept(null);
} }
}); });
@ -111,8 +130,6 @@ public class AnalyticsProcessManager {
} }
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) { try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())));
SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern()) SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setSize(1) .setSize(1)
.setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId()))) .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
@ -246,9 +263,8 @@ public class AnalyticsProcessManager {
private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) { AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); AnalyticsProcess<AnalyticsResult> process =
AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
executorService, onProcessCrash(task));
if (process.isProcessAlive() == false) { if (process.isProcessAlive() == false) {
throw ExceptionsHelper.serverError("Failed to start data frame analytics process"); throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
} }
@ -285,17 +301,22 @@ public class AnalyticsProcessManager {
} }
} }
public void stop(DataFrameAnalyticsTask task) { public synchronized void stop(DataFrameAnalyticsTask task) {
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
if (processContext != null) { if (processContext != null) {
LOGGER.debug("[{}] Stopping process", task.getParams().getId() ); LOGGER.debug("[{}] Stopping process", task.getParams().getId());
processContext.stop(); processContext.stop();
} else { } else {
LOGGER.debug("[{}] No process context to stop", task.getParams().getId() ); LOGGER.debug("[{}] No process context to stop", task.getParams().getId());
task.markAsCompleted(); task.markAsCompleted();
} }
} }
// Visible for testing
int getProcessContextCount() {
return processContextByAllocation.size();
}
class ProcessContext { class ProcessContext {
private final String id; private final String id;
@ -309,31 +330,26 @@ public class AnalyticsProcessManager {
this.id = Objects.requireNonNull(id); this.id = Objects.requireNonNull(id);
} }
public String getId() { synchronized String getFailureReason() {
return id; return failureReason;
} }
public boolean isProcessKilled() { synchronized void setFailureReason(String failureReason) {
return processKilled;
}
private synchronized void setFailureReason(String failureReason) {
// Only set the new reason if there isn't one already as we want to keep the first reason // Only set the new reason if there isn't one already as we want to keep the first reason
if (failureReason != null) { if (this.failureReason == null && failureReason != null) {
this.failureReason = failureReason; this.failureReason = failureReason;
} }
} }
private String getFailureReason() { synchronized void stop() {
return failureReason;
}
public synchronized void stop() {
LOGGER.debug("[{}] Stopping process", id); LOGGER.debug("[{}] Stopping process", id);
processKilled = true; processKilled = true;
if (dataExtractor != null) { if (dataExtractor != null) {
dataExtractor.cancel(); dataExtractor.cancel();
} }
if (resultProcessor != null) {
resultProcessor.cancel();
}
if (process != null) { if (process != null) {
try { try {
process.kill(); process.kill();
@ -346,7 +362,7 @@ public class AnalyticsProcessManager {
/** /**
* @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime * @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
*/ */
private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config, synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
DataFrameAnalyticsTask task, @Nullable BytesReference state) { DataFrameAnalyticsTask task, @Nullable BytesReference state) {
if (processKilled) { if (processKilled) {
// The job was stopped before we started the process so no need to start it // The job was stopped before we started the process so no need to start it
@ -365,8 +381,8 @@ public class AnalyticsProcessManager {
process = createProcess(task, config, analyticsProcessConfig, state); process = createProcess(task, config, analyticsProcessConfig, state);
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
dataExtractorFactory.newExtractor(true)); dataExtractorFactory.newExtractor(true));
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(), resultProcessor = new AnalyticsResultProcessor(
trainedModelProvider, auditor, dataExtractor.getFieldNames()); config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.getFieldNames());
return true; return true;
} }

View File

@ -31,7 +31,6 @@ import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
public class AnalyticsResultProcessor { public class AnalyticsResultProcessor {
@ -39,21 +38,19 @@ public class AnalyticsResultProcessor {
private final DataFrameAnalyticsConfig analytics; private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner; private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final Supplier<Boolean> isProcessKilled;
private final ProgressTracker progressTracker; private final ProgressTracker progressTracker;
private final TrainedModelProvider trainedModelProvider; private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor; private final DataFrameAnalyticsAuditor auditor;
private final List<String> fieldNames; private final List<String> fieldNames;
private final CountDownLatch completionLatch = new CountDownLatch(1); private final CountDownLatch completionLatch = new CountDownLatch(1);
private volatile String failure; private volatile String failure;
private volatile boolean isCancelled;
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker, ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider,
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsAuditor auditor, List<String> fieldNames) {
List<String> fieldNames) {
this.analytics = Objects.requireNonNull(analytics); this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
this.progressTracker = Objects.requireNonNull(progressTracker); this.progressTracker = Objects.requireNonNull(progressTracker);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor); this.auditor = Objects.requireNonNull(auditor);
@ -74,6 +71,10 @@ public class AnalyticsResultProcessor {
} }
} }
public void cancel() {
isCancelled = true;
}
public void process(AnalyticsProcess<AnalyticsResult> process) { public void process(AnalyticsProcess<AnalyticsResult> process) {
long totalRows = process.getConfig().rows(); long totalRows = process.getConfig().rows();
long processedRows = 0; long processedRows = 0;
@ -82,6 +83,9 @@ public class AnalyticsResultProcessor {
try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) { try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
Iterator<AnalyticsResult> iterator = process.readAnalyticsResults(); Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();
while (iterator.hasNext()) { while (iterator.hasNext()) {
if (isCancelled) {
break;
}
AnalyticsResult result = iterator.next(); AnalyticsResult result = iterator.next();
processResult(result, resultsJoiner); processResult(result, resultsJoiner);
if (result.getRowResults() != null) { if (result.getRowResults() != null) {
@ -89,13 +93,13 @@ public class AnalyticsResultProcessor {
progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
} }
} }
if (isProcessKilled.get() == false) { if (isCancelled == false) {
// This means we completed successfully so we need to set the progress to 100. // This means we completed successfully so we need to set the progress to 100.
// This is because due to skipped rows, it is possible the processed rows will not reach the total rows. // This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
progressTracker.writingResultsPercent.set(100); progressTracker.writingResultsPercent.set(100);
} }
} catch (Exception e) { } catch (Exception e) {
if (isProcessKilled.get()) { if (isCancelled) {
// No need to log error as it's due to stopping // No need to log error as it's due to stopping
} else { } else {
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e); LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);

View File

@ -0,0 +1,214 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
 * Test for the basic functionality of {@link AnalyticsProcessManager} and {@link AnalyticsProcessManager.ProcessContext}.
 * This test does not spawn any threads. Instead:
 *  - job is run on a current thread (using {@code DirectExecutorService})
 *  - {@code processData} and {@code processResults} methods are not run at all (using mock executor)
 */
public class AnalyticsProcessManagerTests extends ESTestCase {

    private static final long TASK_ALLOCATION_ID = 123;
    private static final String CONFIG_ID = "config-id";
    private static final int NUM_ROWS = 100;
    private static final int NUM_COLS = 4;
    // An "empty" result object; its contents are irrelevant as results are never actually processed here.
    private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null);

    private Client client;
    private DataFrameAnalyticsAuditor auditor;
    private TrainedModelProvider trainedModelProvider;
    // Executes the job body synchronously on the calling thread (direct executor).
    private ExecutorService executorServiceForJob;
    // Mock executor: submitted process/result tasks are recorded but never run.
    private ExecutorService executorServiceForProcess;
    private AnalyticsProcess<AnalyticsResult> process;
    private AnalyticsProcessFactory<AnalyticsResult> processFactory;
    private DataFrameAnalyticsTask task;
    private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
    private DataFrameDataExtractorFactory dataExtractorFactory;
    private DataFrameDataExtractor dataExtractor;
    private Consumer<Exception> finishHandler;
    private ArgumentCaptor<Exception> exceptionCaptor;
    private AnalyticsProcessManager processManager;

    /**
     * Wires up the mock collaborators and builds the {@link AnalyticsProcessManager} under test
     * using the testing-visible constructor that accepts explicit executors.
     */
    @SuppressWarnings("unchecked")
    @Before
    public void setUpMocks() {
        ThreadPool threadPool = mock(ThreadPool.class);
        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
        client = mock(Client.class);
        when(client.threadPool()).thenReturn(threadPool);
        when(client.execute(any(), any())).thenReturn(mock(ActionFuture.class));

        // Runs the job on the current thread so assertions can be made immediately after runJob() returns.
        executorServiceForJob = EsExecutors.newDirectExecutorService();
        executorServiceForProcess = mock(ExecutorService.class);

        process = mock(AnalyticsProcess.class);
        when(process.isProcessAlive()).thenReturn(true);
        when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT).iterator());
        processFactory = mock(AnalyticsProcessFactory.class);
        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);

        auditor = mock(DataFrameAnalyticsAuditor.class);
        trainedModelProvider = mock(TrainedModelProvider.class);

        task = mock(DataFrameAnalyticsTask.class);
        when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
        when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class));
        dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID);
        dataExtractor = mock(DataFrameDataExtractor.class);
        // A non-empty data frame by default; individual tests override this to simulate an empty frame.
        when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
        dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
        when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);

        finishHandler = mock(Consumer.class);
        exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

        processManager = new AnalyticsProcessManager(
            client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider);
    }

    /**
     * When the task is already stopping, runJob must not create a process context
     * and must complete via the finish handler with no error.
     */
    public void testRunJob_TaskIsStopping() {
        when(task.isStopping()).thenReturn(true);

        processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
        assertThat(processManager.getProcessContextCount(), equalTo(0));

        verify(finishHandler).accept(null);
        verifyNoMoreInteractions(finishHandler);
    }

    /**
     * A second runJob for the same allocation must fail with a server error
     * ("process already exists") and must not register a second context.
     */
    public void testRunJob_ProcessContextAlreadyExists() {
        processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
        assertThat(processManager.getProcessContextCount(), equalTo(1));

        processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
        assertThat(processManager.getProcessContextCount(), equalTo(1));

        verify(finishHandler).accept(exceptionCaptor.capture());
        verifyNoMoreInteractions(finishHandler);

        Exception e = exceptionCaptor.getValue();
        assertThat(e.getMessage(), equalTo("[config-id] Could not create process as one already exists"));
    }

    /**
     * An empty data frame means there is nothing to analyze: the process must not start,
     * the context must be removed (no leak), and the finish handler called with no error.
     */
    public void testRunJob_EmptyDataFrame() {
        when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));

        processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
        assertThat(processManager.getProcessContextCount(), equalTo(0));  // Make sure the process context did not leak

        InOrder inOrder = inOrder(dataExtractor, executorServiceForProcess, process, finishHandler);
        inOrder.verify(dataExtractor).collectDataSummary();
        inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
        inOrder.verify(finishHandler).accept(null);
        verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, finishHandler);
    }

    /**
     * Happy path: the process starts and two tasks ('processData' and 'processResults')
     * are handed to the process executor (which, being a mock, never runs them).
     */
    public void testRunJob_Ok() {
        processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
        assertThat(processManager.getProcessContextCount(), equalTo(1));

        InOrder inOrder = inOrder(dataExtractor, executorServiceForProcess, process, finishHandler);
        inOrder.verify(dataExtractor).collectDataSummary();
        inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
        inOrder.verify(process).isProcessAlive();
        inOrder.verify(dataExtractor).getFieldNames();
        inOrder.verify(executorServiceForProcess, times(2)).execute(any());  // 'processData' and 'processResults' threads
        verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, finishHandler);
    }

    /**
     * The failure reason is write-once: only the first non-null value sticks;
     * later attempts (null or otherwise) are ignored.
     */
    public void testProcessContext_GetSetFailureReason() {
        AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
        assertThat(processContext.getFailureReason(), is(nullValue()));

        processContext.setFailureReason("reason1");
        assertThat(processContext.getFailureReason(), equalTo("reason1"));

        processContext.setFailureReason(null);
        assertThat(processContext.getFailureReason(), equalTo("reason1"));

        processContext.setFailureReason("reason2");
        assertThat(processContext.getFailureReason(), equalTo("reason1"));

        verifyNoMoreInteractions(dataExtractor, process, finishHandler);
    }

    /**
     * startProcess after stop() must be a no-op returning false (the job was
     * stopped before the process could be created).
     */
    public void testProcessContext_StartProcess_ProcessAlreadyKilled() {
        AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
        processContext.stop();
        assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(false));

        verifyNoMoreInteractions(dataExtractor, process, finishHandler);
    }

    /**
     * startProcess on an empty data frame returns false after only inspecting
     * the data summary and categorical fields — no process is created.
     */
    public void testProcessContext_StartProcess_EmptyDataFrame() {
        when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));

        AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
        assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(false));

        InOrder inOrder = inOrder(dataExtractor, process, finishHandler);
        inOrder.verify(dataExtractor).collectDataSummary();
        inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
        verifyNoMoreInteractions(dataExtractor, process, finishHandler);
    }

    /**
     * Full start-then-stop sequence: startProcess succeeds, and stop() cancels the
     * extractor and kills the process, in that order.
     */
    public void testProcessContext_StartAndStop() throws Exception {
        AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
        assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(true));
        processContext.stop();

        InOrder inOrder = inOrder(dataExtractor, process, finishHandler);
        // startProcess
        inOrder.verify(dataExtractor).collectDataSummary();
        inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
        inOrder.verify(process).isProcessAlive();
        inOrder.verify(dataExtractor).getFieldNames();
        // stop
        inOrder.verify(dataExtractor).cancel();
        inOrder.verify(process).kill();
        verifyNoMoreInteractions(dataExtractor, process, finishHandler);
    }

    /**
     * Stopping a context that never started a process must touch neither
     * the extractor nor the process mocks.
     */
    public void testProcessContext_Stop() {
        AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
        processContext.stop();

        verifyNoMoreInteractions(dataExtractor, process, finishHandler);
    }
}

View File

@ -200,7 +200,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
} }
private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) { private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider, return new AnalyticsResultProcessor(
auditor, fieldNames); analyticsConfig, dataFrameRowsJoiner, progressTracker, trainedModelProvider, auditor, fieldNames);
} }
} }