parent
f753fa2265
commit
9c0ec7ce23
|
@ -11,6 +11,7 @@ import org.elasticsearch.action.get.GetResponse;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
import org.elasticsearch.action.search.SearchResponse;
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
|
import org.elasticsearch.common.unit.TimeValue;
|
||||||
import org.elasticsearch.index.query.QueryBuilders;
|
import org.elasticsearch.index.query.QueryBuilders;
|
||||||
import org.elasticsearch.search.SearchHit;
|
import org.elasticsearch.search.SearchHit;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
|
@ -180,7 +181,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Finished analysis");
|
"Finished analysis");
|
||||||
}
|
}
|
||||||
|
|
||||||
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/47612")
|
|
||||||
public void testStopAndRestart() throws Exception {
|
public void testStopAndRestart() throws Exception {
|
||||||
initialize("regression_stop_and_restart");
|
initialize("regression_stop_and_restart");
|
||||||
indexData(sourceIndex, 350, 0);
|
indexData(sourceIndex, 350, 0);
|
||||||
|
@ -197,8 +197,12 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
|
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
|
||||||
assertBusy(() -> {
|
assertBusy(() -> {
|
||||||
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
|
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
|
||||||
assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
|
assertThat(
|
||||||
equalTo(DataFrameAnalyticsState.STOPPED))));
|
state,
|
||||||
|
is(anyOf(
|
||||||
|
equalTo(DataFrameAnalyticsState.REINDEXING),
|
||||||
|
equalTo(DataFrameAnalyticsState.ANALYZING),
|
||||||
|
equalTo(DataFrameAnalyticsState.STOPPED))));
|
||||||
});
|
});
|
||||||
stopAnalytics(jobId);
|
stopAnalytics(jobId);
|
||||||
waitUntilAnalyticsIsStopped(jobId);
|
waitUntilAnalyticsIsStopped(jobId);
|
||||||
|
@ -214,7 +218,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
waitUntilAnalyticsIsStopped(jobId);
|
waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));
|
||||||
|
|
||||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||||
for (SearchHit hit : sourceData.getHits()) {
|
for (SearchHit hit : sourceData.getHits()) {
|
||||||
|
|
|
@ -10,7 +10,6 @@ import org.apache.logging.log4j.Logger;
|
||||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
|
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
|
||||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||||
import org.elasticsearch.action.search.SearchRequest;
|
|
||||||
import org.elasticsearch.action.search.SearchResponse;
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
import org.elasticsearch.client.Client;
|
import org.elasticsearch.client.Client;
|
||||||
import org.elasticsearch.common.Nullable;
|
import org.elasticsearch.common.Nullable;
|
||||||
|
@ -54,7 +53,8 @@ public class AnalyticsProcessManager {
|
||||||
private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
|
private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
|
||||||
|
|
||||||
private final Client client;
|
private final Client client;
|
||||||
private final ThreadPool threadPool;
|
private final ExecutorService executorServiceForJob;
|
||||||
|
private final ExecutorService executorServiceForProcess;
|
||||||
private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
|
private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
|
||||||
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
|
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
|
||||||
private final DataFrameAnalyticsAuditor auditor;
|
private final DataFrameAnalyticsAuditor auditor;
|
||||||
|
@ -65,8 +65,25 @@ public class AnalyticsProcessManager {
|
||||||
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
|
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
|
||||||
DataFrameAnalyticsAuditor auditor,
|
DataFrameAnalyticsAuditor auditor,
|
||||||
TrainedModelProvider trainedModelProvider) {
|
TrainedModelProvider trainedModelProvider) {
|
||||||
|
this(
|
||||||
|
client,
|
||||||
|
threadPool.generic(),
|
||||||
|
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
|
||||||
|
analyticsProcessFactory,
|
||||||
|
auditor,
|
||||||
|
trainedModelProvider);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visible for testing
|
||||||
|
public AnalyticsProcessManager(Client client,
|
||||||
|
ExecutorService executorServiceForJob,
|
||||||
|
ExecutorService executorServiceForProcess,
|
||||||
|
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
|
||||||
|
DataFrameAnalyticsAuditor auditor,
|
||||||
|
TrainedModelProvider trainedModelProvider) {
|
||||||
this.client = Objects.requireNonNull(client);
|
this.client = Objects.requireNonNull(client);
|
||||||
this.threadPool = Objects.requireNonNull(threadPool);
|
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
|
||||||
|
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
|
||||||
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
|
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
|
||||||
this.auditor = Objects.requireNonNull(auditor);
|
this.auditor = Objects.requireNonNull(auditor);
|
||||||
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
||||||
|
@ -74,31 +91,33 @@ public class AnalyticsProcessManager {
|
||||||
|
|
||||||
public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
|
public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
|
||||||
Consumer<Exception> finishHandler) {
|
Consumer<Exception> finishHandler) {
|
||||||
threadPool.generic().execute(() -> {
|
executorServiceForJob.execute(() -> {
|
||||||
if (task.isStopping()) {
|
ProcessContext processContext = new ProcessContext(config.getId());
|
||||||
// The task was requested to stop before we created the process context
|
synchronized (this) {
|
||||||
finishHandler.accept(null);
|
if (task.isStopping()) {
|
||||||
return;
|
// The task was requested to stop before we created the process context
|
||||||
|
finishHandler.accept(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
|
||||||
|
finishHandler.accept(
|
||||||
|
ExceptionsHelper.serverError("[" + config.getId() + "] Could not create process as one already exists"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// First we refresh the dest index to ensure data is searchable
|
// Refresh the dest index to ensure data is searchable
|
||||||
refreshDest(config);
|
refreshDest(config);
|
||||||
|
|
||||||
ProcessContext processContext = new ProcessContext(config.getId());
|
// Fetch existing model state (if any)
|
||||||
if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
|
|
||||||
finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id
|
|
||||||
+ "] Could not create process as one already exists"));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
BytesReference state = getModelState(config);
|
BytesReference state = getModelState(config);
|
||||||
|
|
||||||
if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
|
if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
|
||||||
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
|
executorServiceForProcess.execute(() -> processResults(processContext));
|
||||||
executorService.execute(() -> processResults(processContext));
|
executorServiceForProcess.execute(() -> processData(task, config, processContext.dataExtractor,
|
||||||
executorService.execute(() -> processData(task, config, processContext.dataExtractor,
|
|
||||||
processContext.process, processContext.resultProcessor, finishHandler, state));
|
processContext.process, processContext.resultProcessor, finishHandler, state));
|
||||||
} else {
|
} else {
|
||||||
|
processContextByAllocation.remove(task.getAllocationId());
|
||||||
finishHandler.accept(null);
|
finishHandler.accept(null);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -111,8 +130,6 @@ public class AnalyticsProcessManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
|
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
|
||||||
SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
|
|
||||||
searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())));
|
|
||||||
SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
|
SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
|
||||||
.setSize(1)
|
.setSize(1)
|
||||||
.setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
|
.setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
|
||||||
|
@ -246,9 +263,8 @@ public class AnalyticsProcessManager {
|
||||||
|
|
||||||
private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
|
private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
|
||||||
AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
|
AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
|
||||||
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
|
AnalyticsProcess<AnalyticsResult> process =
|
||||||
AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state,
|
processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
|
||||||
executorService, onProcessCrash(task));
|
|
||||||
if (process.isProcessAlive() == false) {
|
if (process.isProcessAlive() == false) {
|
||||||
throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
|
throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
|
||||||
}
|
}
|
||||||
|
@ -285,17 +301,22 @@ public class AnalyticsProcessManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void stop(DataFrameAnalyticsTask task) {
|
public synchronized void stop(DataFrameAnalyticsTask task) {
|
||||||
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
|
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
|
||||||
if (processContext != null) {
|
if (processContext != null) {
|
||||||
LOGGER.debug("[{}] Stopping process", task.getParams().getId() );
|
LOGGER.debug("[{}] Stopping process", task.getParams().getId());
|
||||||
processContext.stop();
|
processContext.stop();
|
||||||
} else {
|
} else {
|
||||||
LOGGER.debug("[{}] No process context to stop", task.getParams().getId() );
|
LOGGER.debug("[{}] No process context to stop", task.getParams().getId());
|
||||||
task.markAsCompleted();
|
task.markAsCompleted();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visible for testing
|
||||||
|
int getProcessContextCount() {
|
||||||
|
return processContextByAllocation.size();
|
||||||
|
}
|
||||||
|
|
||||||
class ProcessContext {
|
class ProcessContext {
|
||||||
|
|
||||||
private final String id;
|
private final String id;
|
||||||
|
@ -309,31 +330,26 @@ public class AnalyticsProcessManager {
|
||||||
this.id = Objects.requireNonNull(id);
|
this.id = Objects.requireNonNull(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getId() {
|
synchronized String getFailureReason() {
|
||||||
return id;
|
return failureReason;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isProcessKilled() {
|
synchronized void setFailureReason(String failureReason) {
|
||||||
return processKilled;
|
|
||||||
}
|
|
||||||
|
|
||||||
private synchronized void setFailureReason(String failureReason) {
|
|
||||||
// Only set the new reason if there isn't one already as we want to keep the first reason
|
// Only set the new reason if there isn't one already as we want to keep the first reason
|
||||||
if (failureReason != null) {
|
if (this.failureReason == null && failureReason != null) {
|
||||||
this.failureReason = failureReason;
|
this.failureReason = failureReason;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getFailureReason() {
|
synchronized void stop() {
|
||||||
return failureReason;
|
|
||||||
}
|
|
||||||
|
|
||||||
public synchronized void stop() {
|
|
||||||
LOGGER.debug("[{}] Stopping process", id);
|
LOGGER.debug("[{}] Stopping process", id);
|
||||||
processKilled = true;
|
processKilled = true;
|
||||||
if (dataExtractor != null) {
|
if (dataExtractor != null) {
|
||||||
dataExtractor.cancel();
|
dataExtractor.cancel();
|
||||||
}
|
}
|
||||||
|
if (resultProcessor != null) {
|
||||||
|
resultProcessor.cancel();
|
||||||
|
}
|
||||||
if (process != null) {
|
if (process != null) {
|
||||||
try {
|
try {
|
||||||
process.kill();
|
process.kill();
|
||||||
|
@ -346,8 +362,8 @@ public class AnalyticsProcessManager {
|
||||||
/**
|
/**
|
||||||
* @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
|
* @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
|
||||||
*/
|
*/
|
||||||
private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
|
synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
|
||||||
DataFrameAnalyticsTask task, @Nullable BytesReference state) {
|
DataFrameAnalyticsTask task, @Nullable BytesReference state) {
|
||||||
if (processKilled) {
|
if (processKilled) {
|
||||||
// The job was stopped before we started the process so no need to start it
|
// The job was stopped before we started the process so no need to start it
|
||||||
return false;
|
return false;
|
||||||
|
@ -365,8 +381,8 @@ public class AnalyticsProcessManager {
|
||||||
process = createProcess(task, config, analyticsProcessConfig, state);
|
process = createProcess(task, config, analyticsProcessConfig, state);
|
||||||
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
|
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
|
||||||
dataExtractorFactory.newExtractor(true));
|
dataExtractorFactory.newExtractor(true));
|
||||||
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
|
resultProcessor = new AnalyticsResultProcessor(
|
||||||
trainedModelProvider, auditor, dataExtractor.getFieldNames());
|
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.getFieldNames());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,6 @@ import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.function.Supplier;
|
|
||||||
|
|
||||||
public class AnalyticsResultProcessor {
|
public class AnalyticsResultProcessor {
|
||||||
|
|
||||||
|
@ -39,21 +38,19 @@ public class AnalyticsResultProcessor {
|
||||||
|
|
||||||
private final DataFrameAnalyticsConfig analytics;
|
private final DataFrameAnalyticsConfig analytics;
|
||||||
private final DataFrameRowsJoiner dataFrameRowsJoiner;
|
private final DataFrameRowsJoiner dataFrameRowsJoiner;
|
||||||
private final Supplier<Boolean> isProcessKilled;
|
|
||||||
private final ProgressTracker progressTracker;
|
private final ProgressTracker progressTracker;
|
||||||
private final TrainedModelProvider trainedModelProvider;
|
private final TrainedModelProvider trainedModelProvider;
|
||||||
private final DataFrameAnalyticsAuditor auditor;
|
private final DataFrameAnalyticsAuditor auditor;
|
||||||
private final List<String> fieldNames;
|
private final List<String> fieldNames;
|
||||||
private final CountDownLatch completionLatch = new CountDownLatch(1);
|
private final CountDownLatch completionLatch = new CountDownLatch(1);
|
||||||
private volatile String failure;
|
private volatile String failure;
|
||||||
|
private volatile boolean isCancelled;
|
||||||
|
|
||||||
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
|
public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
|
||||||
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
|
ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider,
|
||||||
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor,
|
DataFrameAnalyticsAuditor auditor, List<String> fieldNames) {
|
||||||
List<String> fieldNames) {
|
|
||||||
this.analytics = Objects.requireNonNull(analytics);
|
this.analytics = Objects.requireNonNull(analytics);
|
||||||
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
|
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
|
||||||
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
|
|
||||||
this.progressTracker = Objects.requireNonNull(progressTracker);
|
this.progressTracker = Objects.requireNonNull(progressTracker);
|
||||||
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
||||||
this.auditor = Objects.requireNonNull(auditor);
|
this.auditor = Objects.requireNonNull(auditor);
|
||||||
|
@ -74,6 +71,10 @@ public class AnalyticsResultProcessor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void cancel() {
|
||||||
|
isCancelled = true;
|
||||||
|
}
|
||||||
|
|
||||||
public void process(AnalyticsProcess<AnalyticsResult> process) {
|
public void process(AnalyticsProcess<AnalyticsResult> process) {
|
||||||
long totalRows = process.getConfig().rows();
|
long totalRows = process.getConfig().rows();
|
||||||
long processedRows = 0;
|
long processedRows = 0;
|
||||||
|
@ -82,6 +83,9 @@ public class AnalyticsResultProcessor {
|
||||||
try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
|
try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
|
||||||
Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();
|
Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();
|
||||||
while (iterator.hasNext()) {
|
while (iterator.hasNext()) {
|
||||||
|
if (isCancelled) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
AnalyticsResult result = iterator.next();
|
AnalyticsResult result = iterator.next();
|
||||||
processResult(result, resultsJoiner);
|
processResult(result, resultsJoiner);
|
||||||
if (result.getRowResults() != null) {
|
if (result.getRowResults() != null) {
|
||||||
|
@ -89,13 +93,13 @@ public class AnalyticsResultProcessor {
|
||||||
progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
|
progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (isProcessKilled.get() == false) {
|
if (isCancelled == false) {
|
||||||
// This means we completed successfully so we need to set the progress to 100.
|
// This means we completed successfully so we need to set the progress to 100.
|
||||||
// This is needed because, due to skipped rows, the processed rows may never reach the total rows.
|
// This is needed because, due to skipped rows, the processed rows may never reach the total rows.
|
||||||
progressTracker.writingResultsPercent.set(100);
|
progressTracker.writingResultsPercent.set(100);
|
||||||
}
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
if (isProcessKilled.get()) {
|
if (isCancelled) {
|
||||||
// No need to log error as it's due to stopping
|
// No need to log error as it's due to stopping
|
||||||
} else {
|
} else {
|
||||||
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
|
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
|
||||||
|
|
|
@ -0,0 +1,214 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.ml.dataframe.process;
|
||||||
|
|
||||||
|
import org.elasticsearch.action.ActionFuture;
|
||||||
|
import org.elasticsearch.client.Client;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
||||||
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
|
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.InOrder;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
|
import static org.mockito.Matchers.any;
|
||||||
|
import static org.mockito.Matchers.anyBoolean;
|
||||||
|
import static org.mockito.Mockito.inOrder;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Test for the basic functionality of {@link AnalyticsProcessManager} and {@link AnalyticsProcessManager.ProcessContext}.
|
||||||
|
* This test does not spawn any threads. Instead:
|
||||||
|
* - the job is run on the current thread (using {@code DirectExecutorService})
|
||||||
|
* - the {@code processData} and {@code processResults} methods are not run at all (using a mock executor)
|
||||||
|
*/
|
||||||
|
public class AnalyticsProcessManagerTests extends ESTestCase {
|
||||||
|
|
||||||
|
private static final long TASK_ALLOCATION_ID = 123;
|
||||||
|
private static final String CONFIG_ID = "config-id";
|
||||||
|
private static final int NUM_ROWS = 100;
|
||||||
|
private static final int NUM_COLS = 4;
|
||||||
|
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null);
|
||||||
|
|
||||||
|
private Client client;
|
||||||
|
private DataFrameAnalyticsAuditor auditor;
|
||||||
|
private TrainedModelProvider trainedModelProvider;
|
||||||
|
private ExecutorService executorServiceForJob;
|
||||||
|
private ExecutorService executorServiceForProcess;
|
||||||
|
private AnalyticsProcess<AnalyticsResult> process;
|
||||||
|
private AnalyticsProcessFactory<AnalyticsResult> processFactory;
|
||||||
|
private DataFrameAnalyticsTask task;
|
||||||
|
private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
|
||||||
|
private DataFrameDataExtractorFactory dataExtractorFactory;
|
||||||
|
private DataFrameDataExtractor dataExtractor;
|
||||||
|
private Consumer<Exception> finishHandler;
|
||||||
|
private ArgumentCaptor<Exception> exceptionCaptor;
|
||||||
|
private AnalyticsProcessManager processManager;
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
@Before
|
||||||
|
public void setUpMocks() {
|
||||||
|
ThreadPool threadPool = mock(ThreadPool.class);
|
||||||
|
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
|
||||||
|
client = mock(Client.class);
|
||||||
|
when(client.threadPool()).thenReturn(threadPool);
|
||||||
|
when(client.execute(any(), any())).thenReturn(mock(ActionFuture.class));
|
||||||
|
executorServiceForJob = EsExecutors.newDirectExecutorService();
|
||||||
|
executorServiceForProcess = mock(ExecutorService.class);
|
||||||
|
process = mock(AnalyticsProcess.class);
|
||||||
|
when(process.isProcessAlive()).thenReturn(true);
|
||||||
|
when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT).iterator());
|
||||||
|
processFactory = mock(AnalyticsProcessFactory.class);
|
||||||
|
when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
|
||||||
|
auditor = mock(DataFrameAnalyticsAuditor.class);
|
||||||
|
trainedModelProvider = mock(TrainedModelProvider.class);
|
||||||
|
|
||||||
|
task = mock(DataFrameAnalyticsTask.class);
|
||||||
|
when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
|
||||||
|
when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class));
|
||||||
|
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID);
|
||||||
|
dataExtractor = mock(DataFrameDataExtractor.class);
|
||||||
|
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
|
||||||
|
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
|
||||||
|
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
|
||||||
|
finishHandler = mock(Consumer.class);
|
||||||
|
|
||||||
|
exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
|
||||||
|
|
||||||
|
processManager = new AnalyticsProcessManager(
|
||||||
|
client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRunJob_TaskIsStopping() {
|
||||||
|
when(task.isStopping()).thenReturn(true);
|
||||||
|
|
||||||
|
processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
|
||||||
|
assertThat(processManager.getProcessContextCount(), equalTo(0));
|
||||||
|
|
||||||
|
verify(finishHandler).accept(null);
|
||||||
|
verifyNoMoreInteractions(finishHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRunJob_ProcessContextAlreadyExists() {
|
||||||
|
processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
|
||||||
|
assertThat(processManager.getProcessContextCount(), equalTo(1));
|
||||||
|
processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
|
||||||
|
assertThat(processManager.getProcessContextCount(), equalTo(1));
|
||||||
|
|
||||||
|
verify(finishHandler).accept(exceptionCaptor.capture());
|
||||||
|
verifyNoMoreInteractions(finishHandler);
|
||||||
|
|
||||||
|
Exception e = exceptionCaptor.getValue();
|
||||||
|
assertThat(e.getMessage(), equalTo("[config-id] Could not create process as one already exists"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRunJob_EmptyDataFrame() {
|
||||||
|
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));
|
||||||
|
|
||||||
|
processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
|
||||||
|
assertThat(processManager.getProcessContextCount(), equalTo(0)); // Make sure the process context did not leak
|
||||||
|
|
||||||
|
InOrder inOrder = inOrder(dataExtractor, executorServiceForProcess, process, finishHandler);
|
||||||
|
inOrder.verify(dataExtractor).collectDataSummary();
|
||||||
|
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
|
||||||
|
inOrder.verify(finishHandler).accept(null);
|
||||||
|
verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, finishHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRunJob_Ok() {
|
||||||
|
processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler);
|
||||||
|
assertThat(processManager.getProcessContextCount(), equalTo(1));
|
||||||
|
|
||||||
|
InOrder inOrder = inOrder(dataExtractor, executorServiceForProcess, process, finishHandler);
|
||||||
|
inOrder.verify(dataExtractor).collectDataSummary();
|
||||||
|
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
|
||||||
|
inOrder.verify(process).isProcessAlive();
|
||||||
|
inOrder.verify(dataExtractor).getFieldNames();
|
||||||
|
inOrder.verify(executorServiceForProcess, times(2)).execute(any()); // 'processData' and 'processResults' threads
|
||||||
|
verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, finishHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcessContext_GetSetFailureReason() {
|
||||||
|
AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
|
||||||
|
assertThat(processContext.getFailureReason(), is(nullValue()));
|
||||||
|
|
||||||
|
processContext.setFailureReason("reason1");
|
||||||
|
assertThat(processContext.getFailureReason(), equalTo("reason1"));
|
||||||
|
|
||||||
|
processContext.setFailureReason(null);
|
||||||
|
assertThat(processContext.getFailureReason(), equalTo("reason1"));
|
||||||
|
|
||||||
|
processContext.setFailureReason("reason2");
|
||||||
|
assertThat(processContext.getFailureReason(), equalTo("reason1"));
|
||||||
|
|
||||||
|
verifyNoMoreInteractions(dataExtractor, process, finishHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcessContext_StartProcess_ProcessAlreadyKilled() {
|
||||||
|
AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
|
||||||
|
processContext.stop();
|
||||||
|
assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(false));
|
||||||
|
|
||||||
|
verifyNoMoreInteractions(dataExtractor, process, finishHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcessContext_StartProcess_EmptyDataFrame() {
|
||||||
|
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));
|
||||||
|
|
||||||
|
AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
|
||||||
|
assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(false));
|
||||||
|
|
||||||
|
InOrder inOrder = inOrder(dataExtractor, process, finishHandler);
|
||||||
|
inOrder.verify(dataExtractor).collectDataSummary();
|
||||||
|
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
|
||||||
|
verifyNoMoreInteractions(dataExtractor, process, finishHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcessContext_StartAndStop() throws Exception {
|
||||||
|
AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
|
||||||
|
assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(true));
|
||||||
|
processContext.stop();
|
||||||
|
|
||||||
|
InOrder inOrder = inOrder(dataExtractor, process, finishHandler);
|
||||||
|
// startProcess
|
||||||
|
inOrder.verify(dataExtractor).collectDataSummary();
|
||||||
|
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
|
||||||
|
inOrder.verify(process).isProcessAlive();
|
||||||
|
inOrder.verify(dataExtractor).getFieldNames();
|
||||||
|
// stop
|
||||||
|
inOrder.verify(dataExtractor).cancel();
|
||||||
|
inOrder.verify(process).kill();
|
||||||
|
verifyNoMoreInteractions(dataExtractor, process, finishHandler);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcessContext_Stop() {
|
||||||
|
AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID);
|
||||||
|
processContext.stop();
|
||||||
|
|
||||||
|
verifyNoMoreInteractions(dataExtractor, process, finishHandler);
|
||||||
|
}
|
||||||
|
}
|
|
@ -200,7 +200,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
|
private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
|
||||||
return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider,
|
return new AnalyticsResultProcessor(
|
||||||
auditor, fieldNames);
|
analyticsConfig, dataFrameRowsJoiner, progressTracker, trainedModelProvider, auditor, fieldNames);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue