diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java index 1e35530fe17..9a5d556e1c3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java @@ -16,14 +16,12 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.CheckedConsumer; -import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.AbstractRunnable; -import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -77,20 +75,13 @@ import java.io.InputStream; import java.nio.file.Path; import java.time.Duration; import java.time.ZonedDateTime; -import java.util.ArrayList; import java.util.Date; import java.util.Iterator; -import java.util.List; import java.util.Locale; import java.util.Optional; -import java.util.concurrent.AbstractExecutorService; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -791,99 +782,4 @@ public class AutodetectProcessManager implements ClusterStateListener { upgradeInProgress = MlMetadata.getMlMetadata(event.state()).isUpgradeMode(); } - /* - * The autodetect native process can only handle a single operation at a time. In order to guarantee that, all - * operations are initially added to a queue and a worker thread from ml autodetect threadpool will process each - * operation at a time. - */ - static class AutodetectWorkerExecutorService extends AbstractExecutorService { - - private final ThreadContext contextHolder; - private final CountDownLatch awaitTermination = new CountDownLatch(1); - private final BlockingQueue queue = new LinkedBlockingQueue<>(100); - - private volatile boolean running = true; - - @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors") - AutodetectWorkerExecutorService(ThreadContext contextHolder) { - this.contextHolder = contextHolder; - } - - @Override - public void shutdown() { - running = false; - } - - @Override - public List shutdownNow() { - throw new UnsupportedOperationException("not supported"); - } - - @Override - public boolean isShutdown() { - return running == false; - } - - @Override - public boolean isTerminated() { - return awaitTermination.getCount() == 0; - } - - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return awaitTermination.await(timeout, unit); - } - - @Override - public synchronized void execute(Runnable command) { - if (isShutdown()) { - EsRejectedExecutionException rejected = new EsRejectedExecutionException("autodetect worker service has shutdown", true); - if (command instanceof AbstractRunnable) { - ((AbstractRunnable) command).onRejection(rejected); - } else { - throw rejected; - } - } - - boolean added = queue.offer(contextHolder.preserveContext(command)); - if (added == false) { - throw new ElasticsearchStatusException("Unable to submit operation", RestStatus.TOO_MANY_REQUESTS); - } - } - - void start() { - try { - while (running) { - Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS); - if (runnable != null) { - try { - runnable.run(); - } catch (Exception e) { - logger.error("error handling job operation", e); - } - EsExecutors.rethrowErrors(contextHolder.unwrap(runnable)); - } - } - - synchronized (this) { - // if shutdown with tasks pending notify the handlers - if (queue.isEmpty() == false) { - List notExecuted = new ArrayList<>(); - queue.drainTo(notExecuted); - - for (Runnable runnable : notExecuted) { - if (runnable instanceof AbstractRunnable) { - ((AbstractRunnable) runnable).onRejection( - new EsRejectedExecutionException("unable to process as autodetect worker service has shutdown", true)); - } - } - } - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } finally { - awaitTermination.countDown(); - } - } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorService.java new file mode 100644 index 00000000000..324815513b9 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorService.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.job.process.autodetect; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.SuppressForbidden; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.RestStatus; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/* + * The autodetect native process can only handle a single operation at a time. In order to guarantee that, all + * operations are initially added to a queue and a worker thread from ml autodetect threadpool will process each + * operation at a time. + */ +class AutodetectWorkerExecutorService extends AbstractExecutorService { + + private static final Logger logger = LogManager.getLogger(AutodetectWorkerExecutorService.class); + + private final ThreadContext contextHolder; + private final CountDownLatch awaitTermination = new CountDownLatch(1); + private final BlockingQueue queue = new LinkedBlockingQueue<>(100); + + private volatile boolean running = true; + + @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors") + AutodetectWorkerExecutorService(ThreadContext contextHolder) { + this.contextHolder = contextHolder; + } + + @Override + public void shutdown() { + running = false; + } + + @Override + public List shutdownNow() { + throw new UnsupportedOperationException("not supported"); + } + + @Override + public boolean isShutdown() { + return running == false; + } + + @Override + public boolean isTerminated() { + return awaitTermination.getCount() == 0; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return awaitTermination.await(timeout, unit); + } + + @Override + public synchronized void execute(Runnable command) { + if (isShutdown()) { + EsRejectedExecutionException rejected = new EsRejectedExecutionException("autodetect worker service has shutdown", true); + if (command instanceof AbstractRunnable) { + ((AbstractRunnable) command).onRejection(rejected); + } else { + throw rejected; + } + } + + boolean added = queue.offer(contextHolder.preserveContext(command)); + if (added == false) { + throw new ElasticsearchStatusException("Unable to submit operation", RestStatus.TOO_MANY_REQUESTS); + } + } + + void start() { + try { + while (running) { + Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS); + if (runnable != null) { + try { + runnable.run(); + } catch (Exception e) { + logger.error("error handling job operation", e); + } + EsExecutors.rethrowErrors(contextHolder.unwrap(runnable)); + } + } + + synchronized (this) { + // if shutdown with tasks pending notify the handlers + if (queue.isEmpty() == false) { + List notExecuted = new ArrayList<>(); + queue.drainTo(notExecuted); + + for (Runnable runnable : notExecuted) { + if (runnable instanceof AbstractRunnable) { + ((AbstractRunnable) runnable).onRejection( + new EsRejectedExecutionException("unable to process as autodetect worker service has shutdown", true)); + } + } + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + awaitTermination.countDown(); + } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java index ee02e5237c6..9a147dfd1bc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java @@ -15,7 +15,6 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -26,7 +25,6 @@ import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.junit.annotations.TestLogging; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -50,17 +48,14 @@ import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzerTests import org.elasticsearch.xpack.ml.job.persistence.JobDataCountsPersister; import org.elasticsearch.xpack.ml.job.persistence.JobResultsPersister; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; -import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager.AutodetectWorkerExecutorService; import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.TimeRange; import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory; import org.elasticsearch.xpack.ml.notifications.Auditor; -import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -76,11 +71,9 @@ import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; @@ -93,7 +86,6 @@ import static org.elasticsearch.mock.orig.Mockito.times; import static org.elasticsearch.mock.orig.Mockito.verify; import static org.elasticsearch.mock.orig.Mockito.verifyNoMoreInteractions; import static org.elasticsearch.mock.orig.Mockito.when; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; @@ -103,6 +95,7 @@ import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -115,11 +108,15 @@ import static org.mockito.Mockito.spy; public class AutodetectProcessManagerTests extends ESTestCase { private Environment environment; + private Client client; + private ThreadPool threadPool; private AnalysisRegistry analysisRegistry; private JobManager jobManager; private JobResultsProvider jobResultsProvider; private JobResultsPersister jobResultsPersister; private JobDataCountsPersister jobDataCountsPersister; + private AutodetectCommunicator autodetectCommunicator; + private AutodetectProcessFactory autodetectFactory; private NormalizerFactory normalizerFactory; private Auditor auditor; private ClusterState clusterState; @@ -131,18 +128,24 @@ public class AutodetectProcessManagerTests extends ESTestCase { private Quantiles quantiles = new Quantiles("foo", new Date(), "state"); private Set filters = new HashSet<>(); - private ThreadPool threadPool; - @Before public void setup() throws Exception { Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir()).build(); environment = TestEnvironment.newEnvironment(settings); + client = mock(Client.class); + + threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService()); + analysisRegistry = CategorizationAnalyzerTests.buildTestAnalysisRegistry(environment); jobManager = mock(JobManager.class); jobResultsProvider = mock(JobResultsProvider.class); jobResultsPersister = mock(JobResultsPersister.class); when(jobResultsPersister.bulkPersisterBuilder(any())).thenReturn(mock(JobResultsPersister.Builder.class)); jobDataCountsPersister = mock(JobDataCountsPersister.class); + autodetectCommunicator = mock(AutodetectCommunicator.class); + autodetectFactory = mock(AutodetectProcessFactory.class); normalizerFactory = mock(NormalizerFactory.class); auditor = mock(Auditor.class); clusterService = mock(ClusterService.class); @@ -170,25 +173,16 @@ public class AutodetectProcessManagerTests extends ESTestCase { handler.accept(buildAutodetectParams()); return null; }).when(jobResultsProvider).getAutodetectParams(any(), any(), any()); - - threadPool = new TestThreadPool("AutodetectProcessManagerTests"); - } - - @After - public void stopThreadPool() { - terminate(threadPool); } public void testOpenJob() { - Client client = mock(Client.class); - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; listener.onResponse(createJobDetails("foo")); return null; }).when(jobManager).getJob(eq("foo"), any()); - AutodetectProcessManager manager = createManager(communicator, client); + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); @@ -200,8 +194,6 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testOpenJob_withoutVersion() { - Client client = mock(Client.class); - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); Job.Builder jobBuilder = new Job.Builder(createJobDetails("no_version")); jobBuilder.setJobVersion(null); Job job = jobBuilder.build(); @@ -214,7 +206,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { return null; }).when(jobManager).getJob(eq(job.getId()), any()); - AutodetectProcessManager manager = createManager(communicator, client); + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn(job.getId()); AtomicReference errorHolder = new AtomicReference<>(); @@ -235,25 +227,22 @@ public class AutodetectProcessManagerTests extends ESTestCase { }).when(jobManager).getJob(eq(jobId), any()); } - Client client = mock(Client.class); - ThreadPool threadPool = mock(ThreadPool.class); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ThreadPool.Cancellable cancellable = mock(ThreadPool.Cancellable.class); when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable); - ExecutorService executorService = mock(ExecutorService.class); - Future future = mock(Future.class); - when(executorService.submit(any(Callable.class))).thenReturn(future); - when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService()); + AutodetectProcess autodetectProcess = mock(AutodetectProcess.class); when(autodetectProcess.isProcessAlive()).thenReturn(true); when(autodetectProcess.readAutodetectResults()).thenReturn(Collections.emptyIterator()); - AutodetectProcessFactory autodetectProcessFactory = - (j, autodetectParams, e, onProcessCrash) -> autodetectProcess; + + autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess; Settings.Builder settings = Settings.builder(); settings.put(MachineLearning.MAX_OPEN_JOBS_PER_NODE.getKey(), 3); - AutodetectProcessManager manager = spy(new AutodetectProcessManager(environment, settings.build(), client, threadPool, - jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory, - normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService)); + AutodetectProcessManager manager = createSpyManager(settings.build()); + doCallRealMethod().when(manager).create(any(), any(), any(), any()); + + ExecutorService executorService = mock(ExecutorService.class); + Future future = mock(Future.class); + when(executorService.submit(any(Callable.class))).thenReturn(future); doReturn(executorService).when(manager).createAutodetectExecutorService(any()); doAnswer(invocationOnMock -> { @@ -293,8 +282,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testProcessData() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); assertEquals(0, manager.numberOfOpenJobs()); JobTask jobTask = mock(JobTask.class); @@ -307,8 +295,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testProcessDataThrowsElasticsearchStatusException_onIoException() { - AutodetectCommunicator communicator = Mockito.mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); DataLoadParams params = mock(DataLoadParams.class); InputStream inputStream = createInputStream(""); @@ -318,7 +305,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { BiConsumer handler = (BiConsumer) invocationOnMock.getArguments()[4]; handler.accept(null, new IOException("blah")); return null; - }).when(communicator).writeToJob(eq(inputStream), same(analysisRegistry), same(xContentType), eq(params), any()); + }).when(autodetectCommunicator).writeToJob(eq(inputStream), same(analysisRegistry), same(xContentType), eq(params), any()); JobTask jobTask = mock(JobTask.class); @@ -330,8 +317,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testCloseJob() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); assertEquals(0, manager.numberOfOpenJobs()); JobTask jobTask = mock(JobTask.class); @@ -350,7 +336,6 @@ public class AutodetectProcessManagerTests extends ESTestCase { // interleaved in the AutodetectProcessManager.close() call @TestLogging("org.elasticsearch.xpack.ml.job.process.autodetect:DEBUG") public void testCanCloseClosingJob() throws Exception { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); AtomicInteger numberOfCommunicatorCloses = new AtomicInteger(0); doAnswer(invocationOnMock -> { numberOfCommunicatorCloses.incrementAndGet(); @@ -358,8 +343,8 @@ public class AutodetectProcessManagerTests extends ESTestCase { // the middle of the AutodetectProcessManager.close() method Thread.yield(); return null; - }).when(communicator).close(anyBoolean(), anyString()); - AutodetectProcessManager manager = createManager(communicator); + }).when(autodetectCommunicator).close(anyBoolean(), anyString()); + AutodetectProcessManager manager = createSpyManager(); assertEquals(0, manager.numberOfOpenJobs()); JobTask jobTask = mock(JobTask.class); @@ -395,19 +380,18 @@ public class AutodetectProcessManagerTests extends ESTestCase { CountDownLatch closeStartedLatch = new CountDownLatch(1); CountDownLatch killLatch = new CountDownLatch(1); CountDownLatch closeInterruptedLatch = new CountDownLatch(1); - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); doAnswer(invocationOnMock -> { closeStartedLatch.countDown(); if (killLatch.await(3, TimeUnit.SECONDS)) { closeInterruptedLatch.countDown(); } return null; - }).when(communicator).close(anyBoolean(), anyString()); + }).when(autodetectCommunicator).close(anyBoolean(), anyString()); doAnswer(invocationOnMock -> { killLatch.countDown(); return null; - }).when(communicator).killProcess(anyBoolean(), anyBoolean(), anyBoolean()); - AutodetectProcessManager manager = createManager(communicator); + }).when(autodetectCommunicator).killProcess(anyBoolean(), anyBoolean(), anyBoolean()); + AutodetectProcessManager manager = createSpyManager(); assertEquals(0, manager.numberOfOpenJobs()); JobTask jobTask = mock(JobTask.class); @@ -433,8 +417,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testBucketResetMessageIsSent() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); XContentType xContentType = randomFrom(XContentType.values()); DataLoadParams params = new DataLoadParams(TimeRange.builder().startTime("1000").endTime("2000").build(), Optional.empty()); @@ -443,12 +426,11 @@ public class AutodetectProcessManagerTests extends ESTestCase { when(jobTask.getJobId()).thenReturn("foo"); manager.openJob(jobTask, clusterState, (e, b) -> {}); manager.processData(jobTask, analysisRegistry, inputStream, xContentType, params, (dataCounts1, e) -> {}); - verify(communicator).writeToJob(same(inputStream), same(analysisRegistry), same(xContentType), same(params), any()); + verify(autodetectCommunicator).writeToJob(same(inputStream), same(analysisRegistry), same(xContentType), same(params), any()); } public void testFlush() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); @@ -460,12 +442,11 @@ public class AutodetectProcessManagerTests extends ESTestCase { FlushJobParams params = FlushJobParams.builder().build(); manager.flushJob(jobTask, params, ActionListener.wrap(flushAcknowledgement -> {}, e -> fail(e.getMessage()))); - verify(communicator).flushJob(same(params), any()); + verify(autodetectCommunicator).flushJob(same(params), any()); } public void testFlushThrows() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManagerAndCallProcessData(communicator, "foo"); + AutodetectProcessManager manager = createSpyManagerAndCallProcessData("foo"); FlushJobParams params = FlushJobParams.builder().build(); doAnswer(invocationOnMock -> { @@ -473,7 +454,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { BiConsumer handler = (BiConsumer) invocationOnMock.getArguments()[1]; handler.accept(null, new IOException("blah")); return null; - }).when(communicator).flushJob(same(params), any()); + }).when(autodetectCommunicator).flushJob(same(params), any()); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); @@ -483,12 +464,11 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testCloseThrows() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); // let the communicator throw, simulating a problem with the underlying // autodetect, e.g. a crash - doThrow(Exception.class).when(communicator).close(anyBoolean(), anyString()); + doThrow(Exception.class).when(autodetectCommunicator).close(anyBoolean(), anyString()); // create a jobtask JobTask jobTask = mock(JobTask.class); @@ -507,8 +487,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testWriteUpdateProcessMessage() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManagerAndCallProcessData(communicator, "foo"); + AutodetectProcessManager manager = createSpyManagerAndCallProcessData("foo"); ModelPlotConfig modelConfig = mock(ModelPlotConfig.class); List rules = Collections.singletonList(mock(DetectionRule.class)); List detectorUpdates = Collections.singletonList(new JobUpdate.DetectorUpdate(2, null, rules)); @@ -519,7 +498,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { manager.writeUpdateProcessMessage(jobTask, updateParams, e -> {}); ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateProcessMessage.class); - verify(communicator).writeUpdateProcessMessage(captor.capture(), any()); + verify(autodetectCommunicator).writeUpdateProcessMessage(captor.capture(), any()); UpdateProcessMessage updateProcessMessage = captor.getValue(); assertThat(updateProcessMessage.getModelPlotConfig(), equalTo(modelConfig)); @@ -527,8 +506,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testJobHasActiveAutodetectProcess() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); assertFalse(manager.jobHasActiveAutodetectProcess(jobTask)); @@ -545,8 +523,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testKillKillsAutodetectProcess() throws IOException { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); assertFalse(manager.jobHasActiveAutodetectProcess(jobTask)); @@ -559,12 +536,11 @@ public class AutodetectProcessManagerTests extends ESTestCase { manager.killAllProcessesOnThisNode(); - verify(communicator).killProcess(false, false, true); + verify(autodetectCommunicator).killProcess(false, false, true); } public void testKillingAMissingJobFinishesTheTask() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); - AutodetectProcessManager manager = createManager(communicator); + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); @@ -574,14 +550,13 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testProcessData_GivenStateNotOpened() { - AutodetectCommunicator communicator = mock(AutodetectCommunicator.class); doAnswer(invocationOnMock -> { @SuppressWarnings("unchecked") BiConsumer handler = (BiConsumer) invocationOnMock.getArguments()[4]; handler.accept(new DataCounts("foo"), null); return null; - }).when(communicator).writeToJob(any(), any(), any(), any(), any()); - AutodetectProcessManager manager = createManager(communicator); + }).when(autodetectCommunicator).writeToJob(any(), any(), any(), any(), any()); + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); @@ -595,8 +570,6 @@ public class AutodetectProcessManagerTests extends ESTestCase { } public void testCreate_notEnoughThreads() throws IOException { - Client client = mock(Client.class); - ThreadPool threadPool = mock(ThreadPool.class); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ExecutorService executorService = mock(ExecutorService.class); doThrow(new EsRejectedExecutionException("")).when(executorService).submit(any(Runnable.class)); @@ -611,11 +584,9 @@ public class AutodetectProcessManagerTests extends ESTestCase { }).when(jobManager).getJob(eq("my_id"), any()); AutodetectProcess autodetectProcess = mock(AutodetectProcess.class); - AutodetectProcessFactory autodetectProcessFactory = - (j, autodetectParams, e, onProcessCrash) -> autodetectProcess; - AutodetectProcessManager manager = new AutodetectProcessManager(environment, Settings.EMPTY, - client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory, - normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService); + autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess; + AutodetectProcessManager manager = createSpyManager(); + doCallRealMethod().when(manager).create(any(), any(), any(), any()); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("my_id"); @@ -675,86 +646,7 @@ public class AutodetectProcessManagerTests extends ESTestCase { verifyNoMoreInteractions(auditor); } - public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() { - final ThreadPool threadPool = new TestThreadPool("testAutodetectWorkerExecutorServiceDoesNotSwallowErrors"); - try { - final AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(threadPool.getThreadContext()); - if (randomBoolean()) { - executor.submit(() -> { - throw new Error("future error"); - }); - } else { - executor.execute(() -> { - throw new Error("future error"); - }); - } - final Error e = expectThrows(Error.class, () -> executor.start()); - assertThat(e.getMessage(), containsString("future error")); - } finally { - ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); - } - } - - public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() { - AutodetectProcessManager.AutodetectWorkerExecutorService executor = - new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY)); - - threadPool.generic().execute(() -> executor.start()); - executor.shutdown(); - expectThrows(EsRejectedExecutionException.class, () -> executor.execute(() -> {})); - } - - public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown() - throws InterruptedException, ExecutionException { - AutodetectProcessManager.AutodetectWorkerExecutorService executor = - new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY)); - - CountDownLatch latch = new CountDownLatch(1); - - Future executorFinished = threadPool.generic().submit(() -> executor.start()); - - // run a task that will block while the others are queued up - executor.execute(() -> { - try { - latch.await(); - } catch (InterruptedException e) { - } - }); - - AtomicBoolean runnableShouldNotBeCalled = new AtomicBoolean(false); - executor.execute(() -> runnableShouldNotBeCalled.set(true)); - - AtomicInteger onFailureCallCount = new AtomicInteger(); - AtomicInteger doRunCallCount = new AtomicInteger(); - for (int i=0; i<2; i++) { - executor.execute(new AbstractRunnable() { - @Override - public void onFailure(Exception e) { - onFailureCallCount.incrementAndGet(); - } - - @Override - protected void doRun() { - doRunCallCount.incrementAndGet(); - } - }); - } - - // now shutdown - executor.shutdown(); - latch.countDown(); - executorFinished.get(); - - assertFalse(runnableShouldNotBeCalled.get()); - // the AbstractRunnables should have had their callbacks called - assertEquals(2, onFailureCallCount.get()); - assertEquals(0, doRunCallCount.get()); - } - private AutodetectProcessManager createNonSpyManager(String jobId) { - Client client = mock(Client.class); - ThreadPool threadPool = mock(ThreadPool.class); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); ExecutorService executorService = mock(ExecutorService.class); when(threadPool.executor(anyString())).thenReturn(executorService); when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(mock(ThreadPool.Cancellable.class)); @@ -766,11 +658,8 @@ public class AutodetectProcessManagerTests extends ESTestCase { }).when(jobManager).getJob(eq(jobId), any()); AutodetectProcess autodetectProcess = mock(AutodetectProcess.class); - AutodetectProcessFactory autodetectProcessFactory = - (j, autodetectParams, e, onProcessCrash) -> autodetectProcess; - return new AutodetectProcessManager(environment, Settings.EMPTY, client, threadPool, jobManager, - jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory, - normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService); + autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess; + return createManager(Settings.EMPTY); } private AutodetectParams buildAutodetectParams() { @@ -783,27 +672,25 @@ public class AutodetectProcessManagerTests extends ESTestCase { .build(); } - private AutodetectProcessManager createManager(AutodetectCommunicator communicator) { - Client client = mock(Client.class); - return createManager(communicator, client); + private AutodetectProcessManager createSpyManager() { + return createSpyManager(Settings.EMPTY); } - private AutodetectProcessManager createManager(AutodetectCommunicator communicator, Client client) { - ThreadPool threadPool = mock(ThreadPool.class); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService()); - AutodetectProcessFactory autodetectProcessFactory = mock(AutodetectProcessFactory.class); - AutodetectProcessManager manager = new AutodetectProcessManager(environment, Settings.EMPTY, - client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, - autodetectProcessFactory, normalizerFactory, - new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService); + private AutodetectProcessManager createSpyManager(Settings settings) { + AutodetectProcessManager manager = createManager(settings); manager = spy(manager); - doReturn(communicator).when(manager).create(any(), any(), eq(buildAutodetectParams()), any()); + doReturn(autodetectCommunicator).when(manager).create(any(), any(), eq(buildAutodetectParams()), any()); return manager; } - private AutodetectProcessManager createManagerAndCallProcessData(AutodetectCommunicator communicator, String jobId) { - AutodetectProcessManager manager = createManager(communicator); + private AutodetectProcessManager createManager(Settings settings) { + return new AutodetectProcessManager(environment, settings, + client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, + autodetectFactory, normalizerFactory, + new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService); + } + private AutodetectProcessManager createSpyManagerAndCallProcessData(String jobId) { + AutodetectProcessManager manager = createSpyManager(); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn(jobId); manager.openJob(jobTask, clusterState, (e, b) -> {}); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorServiceTests.java new file mode 100644 index 00000000000..4e9afd38c99 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorServiceTests.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.job.process.autodetect; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.hamcrest.Matchers.containsString; + +public class AutodetectWorkerExecutorServiceTests extends ESTestCase { + + private ThreadPool threadPool = new TestThreadPool("AutodetectWorkerExecutorServiceTests"); + + @After + public void stopThreadPool() { + terminate(threadPool); + } + + public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() { + AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY)); + + threadPool.generic().execute(() -> executor.start()); + executor.shutdown(); + expectThrows(EsRejectedExecutionException.class, () -> executor.execute(() -> {})); + } + + public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown() throws Exception { + AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY)); + + CountDownLatch latch = new CountDownLatch(1); + + Future executorFinished = threadPool.generic().submit(() -> executor.start()); + + // run a task that will block while the others are queued up + executor.execute(() -> { + try { + latch.await(); + } catch (InterruptedException e) { + } + }); + + AtomicBoolean runnableShouldNotBeCalled = new AtomicBoolean(false); + executor.execute(() -> runnableShouldNotBeCalled.set(true)); + + AtomicInteger onFailureCallCount = new AtomicInteger(); + AtomicInteger doRunCallCount = new AtomicInteger(); + for (int i=0; i<2; i++) { + executor.execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + onFailureCallCount.incrementAndGet(); + } + + @Override + protected void doRun() { + doRunCallCount.incrementAndGet(); + } + }); + } + + // now shutdown + executor.shutdown(); + latch.countDown(); + executorFinished.get(); + + assertFalse(runnableShouldNotBeCalled.get()); + // the AbstractRunnables should have had their callbacks called + assertEquals(2, onFailureCallCount.get()); + assertEquals(0, doRunCallCount.get()); + } + + public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() { + AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(threadPool.getThreadContext()); + if (randomBoolean()) { + executor.submit(() -> { + throw new Error("future error"); + }); + } else { + executor.execute(() -> { + throw new Error("future error"); + }); + } + Error e = expectThrows(Error.class, () -> executor.start()); + assertThat(e.getMessage(), containsString("future error")); + } +}