From 7da4724b15b24f79207f0282c26e30ff33b6f3ac Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Thu, 30 Mar 2017 16:08:43 +0200 Subject: [PATCH] [ML] Start using AllocatedPersistentTask#updatePersistentStatus(...) instead of PersistentTasksService directly Original commit: elastic/x-pack-elasticsearch@82a7db17e219fc4a5366c3954be20d8f8c3f5340 --- .../xpack/ml/MachineLearning.java | 6 +- .../xpack/ml/action/OpenJobAction.java | 2 +- .../xpack/ml/datafeed/DatafeedJobRunner.java | 31 +++----- .../autodetect/AutodetectCommunicator.java | 15 +--- .../autodetect/AutodetectProcessManager.java | 34 ++++----- .../ml/datafeed/DatafeedJobRunnerTests.java | 46 +++++++----- .../AutodetectCommunicatorTests.java | 2 +- .../AutodetectProcessManagerTests.java | 74 +++++++++---------- 8 files changed, 98 insertions(+), 112 deletions(-) diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index dab75ef8769..ec9833c80d5 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -295,15 +295,15 @@ public class MachineLearning implements ActionPlugin { } NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory, threadPool.executor(MachineLearning.THREAD_POOL_NAME)); - PersistentTasksService persistentTasksService = new PersistentTasksService(Settings.EMPTY, clusterService, internalClient); AutodetectProcessManager autodetectProcessManager = new AutodetectProcessManager(settings, internalClient, threadPool, jobManager, jobProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory, - normalizerFactory, persistentTasksService, xContentRegistry); + normalizerFactory, xContentRegistry); DatafeedJobRunner datafeedJobRunner = new DatafeedJobRunner(threadPool, internalClient, clusterService, jobProvider, - System::currentTimeMillis, persistentTasksService, auditor); + System::currentTimeMillis, auditor); InvalidLicenseEnforcer invalidLicenseEnforcer = new InvalidLicenseEnforcer(settings, licenseState, threadPool, datafeedJobRunner, autodetectProcessManager); + PersistentTasksService persistentTasksService = new PersistentTasksService(Settings.EMPTY, clusterService, internalClient); PersistentTasksExecutorRegistry persistentTasksExecutorRegistry = new PersistentTasksExecutorRegistry(Settings.EMPTY, Arrays.asList( new OpenJobAction.OpenJobPersistentTasksExecutor(settings, threadPool, licenseState, persistentTasksService, clusterService, autodetectProcessManager, auditor), diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java index f953fbf633d..e9cf0707fb0 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java @@ -386,7 +386,7 @@ public class OpenJobAction extends Action listener) { JobTask jobTask = (JobTask) task; jobTask.autodetectProcessManager = autodetectProcessManager; - autodetectProcessManager.openJob(request.getJobId(), task.getPersistentTaskId(), request.isIgnoreDowntime(), e2 -> { + autodetectProcessManager.openJob(request.getJobId(), jobTask, request.isIgnoreDowntime(), e2 -> { if (e2 == null) { listener.onResponse(new TransportResponse.Empty()); } else { diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunner.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunner.java index 0ceb0ac6ec1..def8726588a 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunner.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunner.java @@ -33,7 +33,6 @@ import org.elasticsearch.xpack.ml.job.results.Bucket; import org.elasticsearch.xpack.ml.job.results.Result; import org.elasticsearch.xpack.ml.notifications.Auditor; import org.elasticsearch.xpack.ml.utils.DatafeedStateObserver; -import org.elasticsearch.xpack.persistent.PersistentTasksService; import org.elasticsearch.xpack.persistent.PersistentTasksService.PersistentTaskOperationListener; import java.time.Duration; @@ -58,19 +57,17 @@ public class DatafeedJobRunner extends AbstractComponent { private final JobProvider jobProvider; private final ThreadPool threadPool; private final Supplier currentTimeSupplier; - private final PersistentTasksService persistentTasksService; private final Auditor auditor; private final ConcurrentMap runningDatafeeds = new ConcurrentHashMap<>(); public DatafeedJobRunner(ThreadPool threadPool, Client client, ClusterService clusterService, JobProvider jobProvider, - Supplier currentTimeSupplier, PersistentTasksService persistentTasksService, Auditor auditor) { + Supplier currentTimeSupplier, Auditor auditor) { super(Settings.EMPTY); this.client = Objects.requireNonNull(client); this.clusterService = Objects.requireNonNull(clusterService); this.jobProvider = Objects.requireNonNull(jobProvider); this.threadPool = threadPool; this.currentTimeSupplier = Objects.requireNonNull(currentTimeSupplier); - this.persistentTasksService = persistentTasksService; this.auditor = auditor; } @@ -93,12 +90,16 @@ public class DatafeedJobRunner extends AbstractComponent { } Holder holder = createJobDatafeed(datafeed, job, latestFinalBucketEndMs, latestRecordTimeMs, handler, task); runningDatafeeds.put(datafeedId, holder); - updateDatafeedState(task.getPersistentTaskId(), DatafeedState.STARTED, e -> { - if (e != null) { - handler.accept(e); - } else { + task.updatePersistentStatus(DatafeedState.STARTED, new PersistentTaskOperationListener() { + @Override + public void onResponse(long taskId) { innerRun(holder, task.getDatafeedStartTime(), task.getEndTime()); } + + @Override + public void onFailure(Exception e) { + handler.accept(e); + } }); }, handler); } @@ -259,20 +260,6 @@ public class DatafeedJobRunner extends AbstractComponent { }); } - private void updateDatafeedState(long persistentTaskId, DatafeedState datafeedState, Consumer handler) { - persistentTasksService.updateStatus(persistentTaskId, datafeedState, new PersistentTaskOperationListener() { - @Override - public void onResponse(long taskId) { - handler.accept(null); - } - - @Override - public void onFailure(Exception e) { - handler.accept(e); - } - }); - } - private static Duration getFrequencyOrDefault(DatafeedConfig datafeed, Job job) { TimeValue frequency = datafeed.getFrequency(); TimeValue bucketSpan = job.getAnalysisConfig().getBucketSpan(); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java index ff9046ebba5..e44b8b56e41 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java @@ -47,7 +47,6 @@ public class AutodetectCommunicator implements Closeable { private static final Logger LOGGER = Loggers.getLogger(AutodetectCommunicator.class); private static final Duration FLUSH_PROCESS_CHECK_FREQUENCY = Duration.ofSeconds(1); - private final long taskId; private final Job job; private final DataCountsReporter dataCountsReporter; private final AutodetectProcess autodetectProcess; @@ -55,13 +54,11 @@ public class AutodetectCommunicator implements Closeable { private final Consumer handler; final AtomicReference inUse = new AtomicReference<>(); - private NamedXContentRegistry xContentRegistry; + private final NamedXContentRegistry xContentRegistry; - public AutodetectCommunicator(long taskId, Job job, AutodetectProcess process, - DataCountsReporter dataCountsReporter, - AutoDetectResultProcessor autoDetectResultProcessor, Consumer handler, - NamedXContentRegistry xContentRegistry) { - this.taskId = taskId; + AutodetectCommunicator(Job job, AutodetectProcess process, DataCountsReporter dataCountsReporter, + AutoDetectResultProcessor autoDetectResultProcessor, Consumer handler, + NamedXContentRegistry xContentRegistry) { this.job = job; this.autodetectProcess = process; this.dataCountsReporter = dataCountsReporter; @@ -185,10 +182,6 @@ public class AutodetectCommunicator implements Closeable { return dataCountsReporter.runningTotalStats(); } - public long getTaskId() { - return taskId; - } - private T checkAndRun(Supplier errorMessage, CheckedSupplier callback, boolean wait) throws IOException { CountDownLatch latch = new CountDownLatch(1); if (inUse.compareAndSet(null, latch)) { diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java index d0b6f6e8ef7..931f0dde452 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.action.OpenJobAction.JobTask; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.config.Job; import org.elasticsearch.xpack.ml.job.config.JobState; @@ -41,7 +42,6 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.Renormalizer; import org.elasticsearch.xpack.ml.job.process.normalizer.ScoresUpdater; import org.elasticsearch.xpack.ml.job.process.normalizer.ShortCircuitingRenormalizer; import org.elasticsearch.xpack.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.persistent.PersistentTasksService; import org.elasticsearch.xpack.persistent.PersistentTasksService.PersistentTaskOperationListener; import java.io.IOException; @@ -77,7 +77,6 @@ public class AutodetectProcessManager extends AbstractComponent { private final JobResultsPersister jobResultsPersister; private final JobDataCountsPersister jobDataCountsPersister; - private final PersistentTasksService persistentTasksService; private final ConcurrentMap autoDetectCommunicatorByJob; @@ -89,7 +88,7 @@ public class AutodetectProcessManager extends AbstractComponent { JobManager jobManager, JobProvider jobProvider, JobResultsPersister jobResultsPersister, JobDataCountsPersister jobDataCountsPersister, AutodetectProcessFactory autodetectProcessFactory, NormalizerFactory normalizerFactory, - PersistentTasksService persistentTasksService, NamedXContentRegistry xContentRegistry) { + NamedXContentRegistry xContentRegistry) { super(settings); this.client = client; this.threadPool = threadPool; @@ -102,7 +101,6 @@ public class AutodetectProcessManager extends AbstractComponent { this.jobResultsPersister = jobResultsPersister; this.jobDataCountsPersister = jobDataCountsPersister; - this.persistentTasksService = persistentTasksService; this.autoDetectCommunicatorByJob = new ConcurrentHashMap<>(); } @@ -207,7 +205,7 @@ public class AutodetectProcessManager extends AbstractComponent { // TODO check for errors from autodetects } - public void openJob(String jobId, long taskId, boolean ignoreDowntime, Consumer handler) { + public void openJob(String jobId, JobTask jobTask, boolean ignoreDowntime, Consumer handler) { Job job = jobManager.getJobOrThrowIfUnknown(jobId); jobProvider.getAutodetectParams(job, params -> { // We need to fork, otherwise we restore model state from a network thread (several GET api calls): @@ -221,9 +219,9 @@ public class AutodetectProcessManager extends AbstractComponent { protected void doRun() throws Exception { try { AutodetectCommunicator communicator = autoDetectCommunicatorByJob.computeIfAbsent(jobId, id -> - create(id, taskId, params, ignoreDowntime, handler)); + create(id, jobTask, params, ignoreDowntime, handler)); communicator.writeJobInputHeader(); - setJobState(taskId, jobId, JobState.OPENED); + setJobState(jobTask, JobState.OPENED); } catch (Exception e1) { if (e1 instanceof ElasticsearchStatusException) { logger.info(e1.getMessage()); @@ -231,17 +229,17 @@ public class AutodetectProcessManager extends AbstractComponent { String msg = String.format(Locale.ROOT, "[%s] exception while opening job", jobId); logger.error(msg, e1); } - setJobState(taskId, JobState.FAILED, e2 -> handler.accept(e1)); + setJobState(jobTask, JobState.FAILED, e2 -> handler.accept(e1)); } } }); }, e1 -> { logger.warn("Failed to gather information required to open job [" + jobId + "]", e1); - setJobState(taskId, JobState.FAILED, e2 -> handler.accept(e1)); + setJobState(jobTask, JobState.FAILED, e2 -> handler.accept(e1)); }); } - AutodetectCommunicator create(String jobId, long taskId, AutodetectParams autodetectParams, + AutodetectCommunicator create(String jobId, JobTask jobTask, AutodetectParams autodetectParams, boolean ignoreDowntime, Consumer handler) { if (autoDetectCommunicatorByJob.size() == maxAllowedRunningJobs) { throw new ElasticsearchStatusException("max running job capacity [" + maxAllowedRunningJobs + "] reached", @@ -269,7 +267,7 @@ public class AutodetectProcessManager extends AbstractComponent { AutodetectProcess process = autodetectProcessFactory.createAutodetectProcess(job, autodetectParams.modelSnapshot(), autodetectParams.quantiles(), autodetectParams.filters(), ignoreDowntime, - executorService, () -> setJobState(taskId, jobId, JobState.FAILED)); + executorService, () -> setJobState(jobTask, JobState.FAILED)); boolean usePerPartitionNormalization = job.getAnalysisConfig().getUsePerPartitionNormalization(); AutoDetectResultProcessor processor = new AutoDetectResultProcessor( client, jobId, renormalizer, jobResultsPersister, autodetectParams.modelSizeStats()); @@ -285,7 +283,7 @@ public class AutodetectProcessManager extends AbstractComponent { } throw e; } - return new AutodetectCommunicator(taskId, job, process, dataCountsReporter, processor, + return new AutodetectCommunicator(job, process, dataCountsReporter, processor, handler, xContentRegistry); } } @@ -335,22 +333,22 @@ public class AutodetectProcessManager extends AbstractComponent { return Optional.of(Duration.between(communicator.getProcessStartTime(), ZonedDateTime.now())); } - private void setJobState(long taskId, String jobId, JobState state) { - persistentTasksService.updateStatus(taskId, state, new PersistentTaskOperationListener() { + private void setJobState(JobTask jobTask, JobState state) { + jobTask.updatePersistentStatus(state, new PersistentTaskOperationListener() { @Override public void onResponse(long taskId) { - logger.info("Successfully set job state to [{}] for job [{}]", state, jobId); + logger.info("Successfully set job state to [{}] for job [{}]", state, jobTask.getJobId()); } @Override public void onFailure(Exception e) { - logger.error("Could not set job state to [" + state + "] for job [" + jobId + "]", e); + logger.error("Could not set job state to [" + state + "] for job [" + jobTask.getJobId() + "]", e); } }); } - public void setJobState(long taskId, JobState state, CheckedConsumer handler) { - persistentTasksService.updateStatus(taskId, state, new PersistentTaskOperationListener() { + public void setJobState(JobTask jobTask, JobState state, CheckedConsumer handler) { + jobTask.updatePersistentStatus(state, new PersistentTaskOperationListener() { @Override public void onResponse(long taskId) { try { diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunnerTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunnerTests.java index 925aee28ca1..b965621fba8 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunnerTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedJobRunnerTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.ml.action.FlushJobAction; import org.elasticsearch.xpack.ml.action.OpenJobAction; import org.elasticsearch.xpack.ml.action.PostDataAction; import org.elasticsearch.xpack.ml.action.StartDatafeedAction; +import org.elasticsearch.xpack.ml.action.StartDatafeedAction.DatafeedTask; import org.elasticsearch.xpack.ml.action.StartDatafeedActionTests; import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractor; import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractorFactory; @@ -44,7 +45,6 @@ import org.elasticsearch.xpack.ml.notifications.AuditMessage; import org.elasticsearch.xpack.ml.notifications.Auditor; import org.elasticsearch.xpack.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.xpack.persistent.PersistentTasksCustomMetaData.PersistentTask; -import org.elasticsearch.xpack.persistent.PersistentTasksService; import org.elasticsearch.xpack.persistent.PersistentTasksService.PersistentTaskOperationListener; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -70,6 +70,7 @@ import static org.mockito.Matchers.same; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -85,7 +86,6 @@ public class DatafeedJobRunnerTests extends ESTestCase { private DatafeedJobRunner datafeedJobRunner; private long currentTime = 120000; private Auditor auditor; - private PersistentTasksService persistentTasksService; @Before @SuppressWarnings("unchecked") @@ -142,9 +142,7 @@ public class DatafeedJobRunnerTests extends ESTestCase { when(client.execute(same(PostDataAction.INSTANCE), any())).thenReturn(jobDataFuture); when(client.execute(same(FlushJobAction.INSTANCE), any())).thenReturn(flushJobFuture); - persistentTasksService = mock(PersistentTasksService.class); - datafeedJobRunner = new DatafeedJobRunner(threadPool, client, clusterService, jobProvider, () -> currentTime, - persistentTasksService, auditor) { + datafeedJobRunner = new DatafeedJobRunner(threadPool, client, clusterService, jobProvider, () -> currentTime, auditor) { @Override DataExtractorFactory createDataExtractorFactory(DatafeedConfig datafeedConfig, Job job) { return dataExtractorFactory; @@ -157,12 +155,6 @@ public class DatafeedJobRunnerTests extends ESTestCase { consumer.accept(new ResourceNotFoundException("dummy")); return null; }).when(jobProvider).bucketsViaInternalClient(any(), any(), any(), any()); - doAnswer(invocationOnMock -> { - @SuppressWarnings("rawtypes") - PersistentTaskOperationListener listener = (PersistentTaskOperationListener) invocationOnMock.getArguments()[2]; - listener.onResponse(0L); - return null; - }).when(persistentTasksService).updateStatus(anyLong(), any(), any()); } public void testLookbackOnly_WarnsWhenNoDataIsRetrieved() throws Exception { @@ -171,7 +163,7 @@ public class DatafeedJobRunnerTests extends ESTestCase { when(dataExtractor.hasNext()).thenReturn(true).thenReturn(false); when(dataExtractor.next()).thenReturn(Optional.empty()); Consumer handler = mockConsumer(); - StartDatafeedAction.DatafeedTask task = createDatafeedTask("datafeed_id", 0L, 60000L); + DatafeedTask task = createDatafeedTask("datafeed_id", 0L, 60000L); datafeedJobRunner.run(task, handler); verify(threadPool, times(1)).executor(MachineLearning.DATAFEED_RUNNER_THREAD_POOL_NAME); @@ -193,7 +185,7 @@ public class DatafeedJobRunnerTests extends ESTestCase { new Date(0), new Date(0), new Date(0), new Date(0), new Date(0)); when(jobDataFuture.actionGet()).thenReturn(new PostDataAction.Response(dataCounts)); Consumer handler = mockConsumer(); - StartDatafeedAction.DatafeedTask task = createDatafeedTask("datafeed_id", 0L, 60000L); + DatafeedTask task = createDatafeedTask("datafeed_id", 0L, 60000L); datafeedJobRunner.run(task, handler); verify(threadPool, times(1)).executor(MachineLearning.DATAFEED_RUNNER_THREAD_POOL_NAME); @@ -223,7 +215,7 @@ public class DatafeedJobRunnerTests extends ESTestCase { new Date(0), new Date(0), new Date(0), new Date(0), new Date(0)); when(jobDataFuture.actionGet()).thenReturn(new PostDataAction.Response(dataCounts)); Consumer handler = mockConsumer(); - StartDatafeedAction.DatafeedTask task = createDatafeedTask("datafeed_id", 0L, 60000L); + DatafeedTask task = createDatafeedTask("datafeed_id", 0L, 60000L); datafeedJobRunner.run(task, handler); verify(threadPool, times(1)).executor(MachineLearning.DATAFEED_RUNNER_THREAD_POOL_NAME); @@ -258,7 +250,7 @@ public class DatafeedJobRunnerTests extends ESTestCase { when(dataExtractorFactory.newExtractor(anyLong(), anyLong())).thenReturn(dataExtractor); when(dataExtractor.hasNext()).thenReturn(false); Consumer handler = mockConsumer(); - StartDatafeedAction.DatafeedTask task = createDatafeedTask("datafeed_id", 0L, null); + DatafeedTask task = createDatafeedTask("datafeed_id", 0L, null); DatafeedJobRunner.Holder holder = datafeedJobRunner.createJobDatafeed(datafeedConfig, job, 100, 100, handler, task); datafeedJobRunner.doDatafeedRealtime(10L, "foo", holder); @@ -282,8 +274,9 @@ public class DatafeedJobRunnerTests extends ESTestCase { Consumer handler = mockConsumer(); boolean cancelled = randomBoolean(); StartDatafeedAction.Request startDatafeedRequest = new StartDatafeedAction.Request("datafeed_id", 0L); - StartDatafeedAction.DatafeedTask task = StartDatafeedActionTests.createDatafeedTask(1, "type", "action", null, + DatafeedTask task = StartDatafeedActionTests.createDatafeedTask(1, "type", "action", null, startDatafeedRequest, datafeedJobRunner); + task = spyDatafeedTask(task); datafeedJobRunner.run(task, handler); verify(threadPool, times(1)).executor(MachineLearning.DATAFEED_RUNNER_THREAD_POOL_NAME); @@ -316,11 +309,17 @@ public class DatafeedJobRunnerTests extends ESTestCase { return builder; } - private static StartDatafeedAction.DatafeedTask createDatafeedTask(String datafeedId, long startTime, Long endTime) { - StartDatafeedAction.DatafeedTask task = mock(StartDatafeedAction.DatafeedTask.class); + private static DatafeedTask createDatafeedTask(String datafeedId, long startTime, Long endTime) { + DatafeedTask task = mock(DatafeedTask.class); when(task.getDatafeedId()).thenReturn(datafeedId); when(task.getDatafeedStartTime()).thenReturn(startTime); when(task.getEndTime()).thenReturn(endTime); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + PersistentTaskOperationListener listener = (PersistentTaskOperationListener) invocationOnMock.getArguments()[1]; + listener.onResponse(0L); + return null; + }).when(task).updatePersistentStatus(any(), any()); return task; } @@ -328,4 +327,15 @@ public class DatafeedJobRunnerTests extends ESTestCase { private Consumer mockConsumer() { return mock(Consumer.class); } + + private DatafeedTask spyDatafeedTask(DatafeedTask task) { + task = spy(task); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + PersistentTaskOperationListener listener = (PersistentTaskOperationListener) invocationOnMock.getArguments()[1]; + listener.onResponse(0L); + return null; + }).when(task).updatePersistentStatus(any(), any()); + return task; + } } diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java index 40b724844c0..8596ca0a72c 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java @@ -151,7 +151,7 @@ public class AutodetectCommunicatorTests extends ESTestCase { return null; }).when(executorService).execute(any(Runnable.class)); DataCountsReporter dataCountsReporter = mock(DataCountsReporter.class); - return new AutodetectCommunicator(0L, createJobDetails(), autodetectProcess, + return new AutodetectCommunicator(createJobDetails(), autodetectProcess, dataCountsReporter, autoDetectResultProcessor, e -> { }, new NamedXContentRegistry(Collections.emptyList())); } diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java index 49a02249365..e4c6a493db0 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.ml.action.OpenJobAction.JobTask; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.ml.job.config.DataDescription; @@ -37,7 +38,6 @@ import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSizeStats; import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.ml.job.process.autodetect.state.Quantiles; import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory; -import org.elasticsearch.xpack.persistent.PersistentTasksService; import org.junit.Before; import org.mockito.Mockito; @@ -63,7 +63,6 @@ import static org.elasticsearch.mock.orig.Mockito.when; import static org.hamcrest.core.IsEqual.equalTo; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; -import static org.mockito.Matchers.anyLong; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; @@ -109,15 +108,15 @@ public class AutodetectProcessManagerTests extends ESTestCase { public void testOpenJob() { Client client = mock(Client.class); - PersistentTasksService persistentTasksService = mock(PersistentTasksService.class); AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); when(jobManager.getJobOrThrowIfUnknown("foo")).thenReturn(createJobDetails("foo")); - AutodetectProcessManager manager = createManager(communicator, client, persistentTasksService); + AutodetectProcessManager manager = createManager(communicator, client); - manager.openJob("foo", 1L, false, e -> {}); + JobTask jobTask = mock(JobTask.class); + manager.openJob("foo", jobTask, false, e -> {}); assertEquals(1, manager.numberOfOpenJobs()); assertTrue(manager.jobHasActiveAutodetectProcess("foo")); - verify(persistentTasksService).updateStatus(eq(1L), eq(JobState.OPENED), any()); + verify(jobTask).updatePersistentStatus(eq(JobState.OPENED), any()); } public void testOpenJob_exceedMaxNumJobs() { @@ -127,7 +126,6 @@ public class AutodetectProcessManagerTests extends ESTestCase { when(jobManager.getJobOrThrowIfUnknown("foobar")).thenReturn(createJobDetails("foobar")); Client client = mock(Client.class); - PersistentTasksService persistentTasksService = mock(PersistentTasksService.class); ThreadPool threadPool = mock(ThreadPool.class); ThreadPool.Cancellable cancellable = mock(ThreadPool.Cancellable.class); when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); @@ -141,33 +139,29 @@ public class AutodetectProcessManagerTests extends ESTestCase { settings.put(AutodetectProcessManager.MAX_RUNNING_JOBS_PER_NODE.getKey(), 3); AutodetectProcessManager manager = spy(new AutodetectProcessManager(settings.build(), client, threadPool, jobManager, jobProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory, - normalizerFactory, persistentTasksService, - new NamedXContentRegistry(Collections.emptyList()))); + normalizerFactory, new NamedXContentRegistry(Collections.emptyList()))); - DataCounts dataCounts = new DataCounts("foo"); - ModelSnapshot modelSnapshot = new ModelSnapshot.Builder("foo").build(); - Quantiles quantiles = new Quantiles("foo", new Date(), "state"); - Set filters = new HashSet<>(); doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") CheckedConsumer consumer = (CheckedConsumer) invocationOnMock.getArguments()[2]; consumer.accept(null); return null; - }).when(manager).setJobState(anyLong(), eq(JobState.FAILED), any()); + }).when(manager).setJobState(any(), eq(JobState.FAILED), any()); - manager.openJob("foo", 1L, false, e -> {}); - manager.openJob("bar", 2L, false, e -> {}); - manager.openJob("baz", 3L, false, e -> {}); + JobTask jobTask = mock(JobTask.class); + manager.openJob("foo", jobTask, false, e -> {}); + manager.openJob("bar", jobTask, false, e -> {}); + manager.openJob("baz", jobTask, false, e -> {}); assertEquals(3, manager.numberOfOpenJobs()); Exception[] holder = new Exception[1]; - manager.openJob("foobar", 4L, false, e -> holder[0] = e); + manager.openJob("foobar", jobTask, false, e -> holder[0] = e); Exception e = holder[0]; assertEquals("max running job capacity [3] reached", e.getMessage()); manager.closeJob("baz", false, null); assertEquals(2, manager.numberOfOpenJobs()); - manager.openJob("foobar", 4L, false, e1 -> {}); + manager.openJob("foobar", jobTask, false, e1 -> {}); assertEquals(3, manager.numberOfOpenJobs()); } @@ -176,8 +170,9 @@ public class AutodetectProcessManagerTests extends ESTestCase { AutodetectProcessManager manager = createManager(communicator); assertEquals(0, manager.numberOfOpenJobs()); + JobTask jobTask = mock(JobTask.class); DataLoadParams params = new DataLoadParams(TimeRange.builder().build(), Optional.empty()); - manager.openJob("foo", 1L, false, e -> {}); + manager.openJob("foo", jobTask, false, e -> {}); manager.processData("foo", createInputStream(""), randomFrom(XContentType.values()), params); assertEquals(1, manager.numberOfOpenJobs()); @@ -193,7 +188,8 @@ public class AutodetectProcessManagerTests extends ESTestCase { doThrow(new IOException("blah")).when(communicator).writeToJob(inputStream, xContentType, params); - manager.openJob("foo", 1L, false, e -> {}); + JobTask jobTask = mock(JobTask.class); + manager.openJob("foo", jobTask, false, e -> {}); ESTestCase.expectThrows(ElasticsearchException.class, () -> manager.processData("foo", inputStream, xContentType, params)); } @@ -203,7 +199,8 @@ public class AutodetectProcessManagerTests extends ESTestCase { AutodetectProcessManager manager = createManager(communicator); assertEquals(0, manager.numberOfOpenJobs()); - manager.openJob("foo", 1L, false, e -> {}); + JobTask jobTask = mock(JobTask.class); + manager.openJob("foo", jobTask, false, e -> {}); manager.processData("foo", createInputStream(""), randomFrom(XContentType.values()), mock(DataLoadParams.class)); @@ -220,7 +217,8 @@ public class AutodetectProcessManagerTests extends ESTestCase { DataLoadParams params = new DataLoadParams(TimeRange.builder().startTime("1000").endTime("2000").build(), Optional.empty()); InputStream inputStream = createInputStream(""); - manager.openJob("foo", 1L, false, e -> {}); + JobTask jobTask = mock(JobTask.class); + manager.openJob("foo", jobTask, false, e -> {}); manager.processData("foo", inputStream, xContentType, params); verify(communicator).writeToJob(inputStream, xContentType, params); } @@ -229,8 +227,9 @@ public class AutodetectProcessManagerTests extends ESTestCase { AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); AutodetectProcessManager manager = createManager(communicator); + JobTask jobTask = mock(JobTask.class); InputStream inputStream = createInputStream(""); - manager.openJob("foo", 1L, false, e -> {}); + manager.openJob("foo", jobTask, false, e -> {}); manager.processData("foo", inputStream, randomFrom(XContentType.values()), mock(DataLoadParams.class)); @@ -267,7 +266,8 @@ public class AutodetectProcessManagerTests extends ESTestCase { AutodetectProcessManager manager = createManager(communicator); assertFalse(manager.jobHasActiveAutodetectProcess("foo")); - manager.openJob("foo", 1L, false, e -> {}); + JobTask jobTask = mock(JobTask.class); + manager.openJob("foo", jobTask, false, e -> {}); manager.processData("foo", createInputStream(""), randomFrom(XContentType.values()), mock(DataLoadParams.class)); @@ -280,8 +280,9 @@ public class AutodetectProcessManagerTests extends ESTestCase { when(communicator.writeToJob(any(), any(), any())).thenReturn(new DataCounts("foo")); AutodetectProcessManager manager = createManager(communicator); + JobTask jobTask = mock(JobTask.class); + manager.openJob("foo", jobTask, false, e -> {}); InputStream inputStream = createInputStream(""); - manager.openJob("foo", 1L, false, e -> {}); DataCounts dataCounts = manager.processData("foo", inputStream, randomFrom(XContentType.values()), mock(DataLoadParams.class)); @@ -296,17 +297,16 @@ public class AutodetectProcessManagerTests extends ESTestCase { when(threadPool.executor(anyString())).thenReturn(executorService); when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(mock(ThreadPool.Cancellable.class)); when(jobManager.getJobOrThrowIfUnknown("my_id")).thenReturn(createJobDetails("my_id")); - PersistentTasksService persistentTasksService = mock(PersistentTasksService.class); AutodetectProcess autodetectProcess = mock(AutodetectProcess.class); AutodetectProcessFactory autodetectProcessFactory = (j, modelSnapshot, quantiles, filters, i, e, onProcessCrash) -> autodetectProcess; AutodetectProcessManager manager = new AutodetectProcessManager(Settings.EMPTY, client, threadPool, jobManager, jobProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory, - normalizerFactory, persistentTasksService, - new NamedXContentRegistry(Collections.emptyList())); + normalizerFactory, new NamedXContentRegistry(Collections.emptyList())); + JobTask jobTask = mock(JobTask.class); expectThrows(EsRejectedExecutionException.class, - () -> manager.create("my_id", 1L, buildAutodetectParams(), false, e -> {})); + () -> manager.create("my_id", jobTask, buildAutodetectParams(), false, e -> {})); verify(autodetectProcess, times(1)).close(); } @@ -322,27 +322,25 @@ public class AutodetectProcessManagerTests extends ESTestCase { private AutodetectProcessManager createManager(AutodetectCommunicator communicator) { Client client = mock(Client.class); - PersistentTasksService persistentTasksService = mock(PersistentTasksService.class); - return createManager(communicator, client, persistentTasksService); + return createManager(communicator, client); } - private AutodetectProcessManager createManager(AutodetectCommunicator communicator, Client client, - PersistentTasksService persistentTasksService) { + private AutodetectProcessManager createManager(AutodetectCommunicator communicator, Client client) { ThreadPool threadPool = mock(ThreadPool.class); when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService()); AutodetectProcessFactory autodetectProcessFactory = mock(AutodetectProcessFactory.class); AutodetectProcessManager manager = new AutodetectProcessManager(Settings.EMPTY, client, threadPool, jobManager, jobProvider, jobResultsPersister, jobDataCountsPersister, - autodetectProcessFactory, normalizerFactory, persistentTasksService, - new NamedXContentRegistry(Collections.emptyList())); + autodetectProcessFactory, normalizerFactory, new NamedXContentRegistry(Collections.emptyList())); manager = spy(manager); - doReturn(communicator).when(manager).create(any(), anyLong(), eq(buildAutodetectParams()), anyBoolean(), any()); + doReturn(communicator).when(manager).create(any(), any(), eq(buildAutodetectParams()), anyBoolean(), any()); return manager; } private AutodetectProcessManager createManagerAndCallProcessData(AutodetectCommunicator communicator, String jobId) { AutodetectProcessManager manager = createManager(communicator); - manager.openJob(jobId, 1L, false, e -> {}); + JobTask jobTask = mock(JobTask.class); + manager.openJob(jobId, jobTask, false, e -> {}); manager.processData(jobId, createInputStream(""), randomFrom(XContentType.values()), mock(DataLoadParams.class)); return manager;