[7.1][ML] Refactor autodetect service into its own class (#41378) (#41409)

This also improves aims to improve the corresponding unit tests
with regard to readability and maintainability.
This commit is contained in:
Dimitris Athanasiou 2019-04-22 17:42:13 +03:00 committed by GitHub
parent d2a418152d
commit eb2295ac81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 289 additions and 284 deletions

View File

@ -16,14 +16,12 @@ import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@ -77,20 +75,13 @@ import java.io.InputStream;
import java.nio.file.Path;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@ -791,99 +782,4 @@ public class AutodetectProcessManager implements ClusterStateListener {
upgradeInProgress = MlMetadata.getMlMetadata(event.state()).isUpgradeMode();
}
/*
* The autodetect native process can only handle a single operation at a time. In order to guarantee that, all
* operations are initially added to a queue and a worker thread from ml autodetect threadpool will process each
* operation at a time.
*/
static class AutodetectWorkerExecutorService extends AbstractExecutorService {
private final ThreadContext contextHolder;
private final CountDownLatch awaitTermination = new CountDownLatch(1);
private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(100);
private volatile boolean running = true;
@SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
AutodetectWorkerExecutorService(ThreadContext contextHolder) {
this.contextHolder = contextHolder;
}
@Override
public void shutdown() {
running = false;
}
@Override
public List<Runnable> shutdownNow() {
throw new UnsupportedOperationException("not supported");
}
@Override
public boolean isShutdown() {
return running == false;
}
@Override
public boolean isTerminated() {
return awaitTermination.getCount() == 0;
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return awaitTermination.await(timeout, unit);
}
@Override
public synchronized void execute(Runnable command) {
if (isShutdown()) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException("autodetect worker service has shutdown", true);
if (command instanceof AbstractRunnable) {
((AbstractRunnable) command).onRejection(rejected);
} else {
throw rejected;
}
}
boolean added = queue.offer(contextHolder.preserveContext(command));
if (added == false) {
throw new ElasticsearchStatusException("Unable to submit operation", RestStatus.TOO_MANY_REQUESTS);
}
}
void start() {
try {
while (running) {
Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS);
if (runnable != null) {
try {
runnable.run();
} catch (Exception e) {
logger.error("error handling job operation", e);
}
EsExecutors.rethrowErrors(contextHolder.unwrap(runnable));
}
}
synchronized (this) {
// if shutdown with tasks pending notify the handlers
if (queue.isEmpty() == false) {
List<Runnable> notExecuted = new ArrayList<>();
queue.drainTo(notExecuted);
for (Runnable runnable : notExecuted) {
if (runnable instanceof AbstractRunnable) {
((AbstractRunnable) runnable).onRejection(
new EsRejectedExecutionException("unable to process as autodetect worker service has shutdown", true));
}
}
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
awaitTermination.countDown();
}
}
}
}

View File

@ -0,0 +1,122 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.job.process.autodetect;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.rest.RestStatus;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
/*
* The autodetect native process can only handle a single operation at a time. In order to guarantee that, all
* operations are initially added to a queue and a worker thread from ml autodetect threadpool will process each
* operation at a time.
*/
class AutodetectWorkerExecutorService extends AbstractExecutorService {
private static final Logger logger = LogManager.getLogger(AutodetectWorkerExecutorService.class);
private final ThreadContext contextHolder;
private final CountDownLatch awaitTermination = new CountDownLatch(1);
private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(100);
private volatile boolean running = true;
@SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
AutodetectWorkerExecutorService(ThreadContext contextHolder) {
this.contextHolder = contextHolder;
}
@Override
public void shutdown() {
running = false;
}
@Override
public List<Runnable> shutdownNow() {
throw new UnsupportedOperationException("not supported");
}
@Override
public boolean isShutdown() {
return running == false;
}
@Override
public boolean isTerminated() {
return awaitTermination.getCount() == 0;
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return awaitTermination.await(timeout, unit);
}
@Override
public synchronized void execute(Runnable command) {
if (isShutdown()) {
EsRejectedExecutionException rejected = new EsRejectedExecutionException("autodetect worker service has shutdown", true);
if (command instanceof AbstractRunnable) {
((AbstractRunnable) command).onRejection(rejected);
} else {
throw rejected;
}
}
boolean added = queue.offer(contextHolder.preserveContext(command));
if (added == false) {
throw new ElasticsearchStatusException("Unable to submit operation", RestStatus.TOO_MANY_REQUESTS);
}
}
void start() {
try {
while (running) {
Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS);
if (runnable != null) {
try {
runnable.run();
} catch (Exception e) {
logger.error("error handling job operation", e);
}
EsExecutors.rethrowErrors(contextHolder.unwrap(runnable));
}
}
synchronized (this) {
// if shutdown with tasks pending notify the handlers
if (queue.isEmpty() == false) {
List<Runnable> notExecuted = new ArrayList<>();
queue.drainTo(notExecuted);
for (Runnable runnable : notExecuted) {
if (runnable instanceof AbstractRunnable) {
((AbstractRunnable) runnable).onRejection(
new EsRejectedExecutionException("unable to process as autodetect worker service has shutdown", true));
}
}
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
awaitTermination.countDown();
}
}
}

View File

@ -15,7 +15,6 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -26,7 +25,6 @@ import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@ -50,17 +48,14 @@ import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzerTests
import org.elasticsearch.xpack.ml.job.persistence.JobDataCountsPersister;
import org.elasticsearch.xpack.ml.job.persistence.JobResultsPersister;
import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager.AutodetectWorkerExecutorService;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.TimeRange;
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory;
import org.elasticsearch.xpack.ml.notifications.Auditor;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import java.io.ByteArrayInputStream;
import java.io.IOException;
@ -76,11 +71,9 @@ import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
@ -93,7 +86,6 @@ import static org.elasticsearch.mock.orig.Mockito.times;
import static org.elasticsearch.mock.orig.Mockito.verify;
import static org.elasticsearch.mock.orig.Mockito.verifyNoMoreInteractions;
import static org.elasticsearch.mock.orig.Mockito.when;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
@ -103,6 +95,7 @@ import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
@ -115,11 +108,15 @@ import static org.mockito.Mockito.spy;
public class AutodetectProcessManagerTests extends ESTestCase {
private Environment environment;
private Client client;
private ThreadPool threadPool;
private AnalysisRegistry analysisRegistry;
private JobManager jobManager;
private JobResultsProvider jobResultsProvider;
private JobResultsPersister jobResultsPersister;
private JobDataCountsPersister jobDataCountsPersister;
private AutodetectCommunicator autodetectCommunicator;
private AutodetectProcessFactory autodetectFactory;
private NormalizerFactory normalizerFactory;
private Auditor auditor;
private ClusterState clusterState;
@ -131,18 +128,24 @@ public class AutodetectProcessManagerTests extends ESTestCase {
private Quantiles quantiles = new Quantiles("foo", new Date(), "state");
private Set<MlFilter> filters = new HashSet<>();
private ThreadPool threadPool;
@Before
public void setup() throws Exception {
Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir()).build();
environment = TestEnvironment.newEnvironment(settings);
client = mock(Client.class);
threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService());
analysisRegistry = CategorizationAnalyzerTests.buildTestAnalysisRegistry(environment);
jobManager = mock(JobManager.class);
jobResultsProvider = mock(JobResultsProvider.class);
jobResultsPersister = mock(JobResultsPersister.class);
when(jobResultsPersister.bulkPersisterBuilder(any())).thenReturn(mock(JobResultsPersister.Builder.class));
jobDataCountsPersister = mock(JobDataCountsPersister.class);
autodetectCommunicator = mock(AutodetectCommunicator.class);
autodetectFactory = mock(AutodetectProcessFactory.class);
normalizerFactory = mock(NormalizerFactory.class);
auditor = mock(Auditor.class);
clusterService = mock(ClusterService.class);
@ -170,25 +173,16 @@ public class AutodetectProcessManagerTests extends ESTestCase {
handler.accept(buildAutodetectParams());
return null;
}).when(jobResultsProvider).getAutodetectParams(any(), any(), any());
threadPool = new TestThreadPool("AutodetectProcessManagerTests");
}
@After
public void stopThreadPool() {
terminate(threadPool);
}
public void testOpenJob() {
Client client = mock(Client.class);
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
ActionListener<Job> listener = (ActionListener<Job>) invocationOnMock.getArguments()[1];
listener.onResponse(createJobDetails("foo"));
return null;
}).when(jobManager).getJob(eq("foo"), any());
AutodetectProcessManager manager = createManager(communicator, client);
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
@ -200,8 +194,6 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testOpenJob_withoutVersion() {
Client client = mock(Client.class);
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
Job.Builder jobBuilder = new Job.Builder(createJobDetails("no_version"));
jobBuilder.setJobVersion(null);
Job job = jobBuilder.build();
@ -214,7 +206,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
return null;
}).when(jobManager).getJob(eq(job.getId()), any());
AutodetectProcessManager manager = createManager(communicator, client);
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn(job.getId());
AtomicReference<Exception> errorHolder = new AtomicReference<>();
@ -235,25 +227,22 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}).when(jobManager).getJob(eq(jobId), any());
}
Client client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ThreadPool.Cancellable cancellable = mock(ThreadPool.Cancellable.class);
when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable);
ExecutorService executorService = mock(ExecutorService.class);
Future<?> future = mock(Future.class);
when(executorService.submit(any(Callable.class))).thenReturn(future);
when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService());
AutodetectProcess autodetectProcess = mock(AutodetectProcess.class);
when(autodetectProcess.isProcessAlive()).thenReturn(true);
when(autodetectProcess.readAutodetectResults()).thenReturn(Collections.emptyIterator());
AutodetectProcessFactory autodetectProcessFactory =
(j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
Settings.Builder settings = Settings.builder();
settings.put(MachineLearning.MAX_OPEN_JOBS_PER_NODE.getKey(), 3);
AutodetectProcessManager manager = spy(new AutodetectProcessManager(environment, settings.build(), client, threadPool,
jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory,
normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService));
AutodetectProcessManager manager = createSpyManager(settings.build());
doCallRealMethod().when(manager).create(any(), any(), any(), any());
ExecutorService executorService = mock(ExecutorService.class);
Future<?> future = mock(Future.class);
when(executorService.submit(any(Callable.class))).thenReturn(future);
doReturn(executorService).when(manager).createAutodetectExecutorService(any());
doAnswer(invocationOnMock -> {
@ -293,8 +282,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testProcessData() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
assertEquals(0, manager.numberOfOpenJobs());
JobTask jobTask = mock(JobTask.class);
@ -307,8 +295,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testProcessDataThrowsElasticsearchStatusException_onIoException() {
AutodetectCommunicator communicator = Mockito.mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
DataLoadParams params = mock(DataLoadParams.class);
InputStream inputStream = createInputStream("");
@ -318,7 +305,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
BiConsumer<DataCounts, Exception> handler = (BiConsumer<DataCounts, Exception>) invocationOnMock.getArguments()[4];
handler.accept(null, new IOException("blah"));
return null;
}).when(communicator).writeToJob(eq(inputStream), same(analysisRegistry), same(xContentType), eq(params), any());
}).when(autodetectCommunicator).writeToJob(eq(inputStream), same(analysisRegistry), same(xContentType), eq(params), any());
JobTask jobTask = mock(JobTask.class);
@ -330,8 +317,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testCloseJob() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
assertEquals(0, manager.numberOfOpenJobs());
JobTask jobTask = mock(JobTask.class);
@ -350,7 +336,6 @@ public class AutodetectProcessManagerTests extends ESTestCase {
// interleaved in the AutodetectProcessManager.close() call
@TestLogging("org.elasticsearch.xpack.ml.job.process.autodetect:DEBUG")
public void testCanCloseClosingJob() throws Exception {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AtomicInteger numberOfCommunicatorCloses = new AtomicInteger(0);
doAnswer(invocationOnMock -> {
numberOfCommunicatorCloses.incrementAndGet();
@ -358,8 +343,8 @@ public class AutodetectProcessManagerTests extends ESTestCase {
// the middle of the AutodetectProcessManager.close() method
Thread.yield();
return null;
}).when(communicator).close(anyBoolean(), anyString());
AutodetectProcessManager manager = createManager(communicator);
}).when(autodetectCommunicator).close(anyBoolean(), anyString());
AutodetectProcessManager manager = createSpyManager();
assertEquals(0, manager.numberOfOpenJobs());
JobTask jobTask = mock(JobTask.class);
@ -395,19 +380,18 @@ public class AutodetectProcessManagerTests extends ESTestCase {
CountDownLatch closeStartedLatch = new CountDownLatch(1);
CountDownLatch killLatch = new CountDownLatch(1);
CountDownLatch closeInterruptedLatch = new CountDownLatch(1);
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
doAnswer(invocationOnMock -> {
closeStartedLatch.countDown();
if (killLatch.await(3, TimeUnit.SECONDS)) {
closeInterruptedLatch.countDown();
}
return null;
}).when(communicator).close(anyBoolean(), anyString());
}).when(autodetectCommunicator).close(anyBoolean(), anyString());
doAnswer(invocationOnMock -> {
killLatch.countDown();
return null;
}).when(communicator).killProcess(anyBoolean(), anyBoolean(), anyBoolean());
AutodetectProcessManager manager = createManager(communicator);
}).when(autodetectCommunicator).killProcess(anyBoolean(), anyBoolean(), anyBoolean());
AutodetectProcessManager manager = createSpyManager();
assertEquals(0, manager.numberOfOpenJobs());
JobTask jobTask = mock(JobTask.class);
@ -433,8 +417,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testBucketResetMessageIsSent() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
XContentType xContentType = randomFrom(XContentType.values());
DataLoadParams params = new DataLoadParams(TimeRange.builder().startTime("1000").endTime("2000").build(), Optional.empty());
@ -443,12 +426,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
when(jobTask.getJobId()).thenReturn("foo");
manager.openJob(jobTask, clusterState, (e, b) -> {});
manager.processData(jobTask, analysisRegistry, inputStream, xContentType, params, (dataCounts1, e) -> {});
verify(communicator).writeToJob(same(inputStream), same(analysisRegistry), same(xContentType), same(params), any());
verify(autodetectCommunicator).writeToJob(same(inputStream), same(analysisRegistry), same(xContentType), same(params), any());
}
public void testFlush() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
@ -460,12 +442,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
FlushJobParams params = FlushJobParams.builder().build();
manager.flushJob(jobTask, params, ActionListener.wrap(flushAcknowledgement -> {}, e -> fail(e.getMessage())));
verify(communicator).flushJob(same(params), any());
verify(autodetectCommunicator).flushJob(same(params), any());
}
public void testFlushThrows() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManagerAndCallProcessData(communicator, "foo");
AutodetectProcessManager manager = createSpyManagerAndCallProcessData("foo");
FlushJobParams params = FlushJobParams.builder().build();
doAnswer(invocationOnMock -> {
@ -473,7 +454,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
BiConsumer<Void, Exception> handler = (BiConsumer<Void, Exception>) invocationOnMock.getArguments()[1];
handler.accept(null, new IOException("blah"));
return null;
}).when(communicator).flushJob(same(params), any());
}).when(autodetectCommunicator).flushJob(same(params), any());
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
@ -483,12 +464,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testCloseThrows() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
// let the communicator throw, simulating a problem with the underlying
// autodetect, e.g. a crash
doThrow(Exception.class).when(communicator).close(anyBoolean(), anyString());
doThrow(Exception.class).when(autodetectCommunicator).close(anyBoolean(), anyString());
// create a jobtask
JobTask jobTask = mock(JobTask.class);
@ -507,8 +487,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testWriteUpdateProcessMessage() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManagerAndCallProcessData(communicator, "foo");
AutodetectProcessManager manager = createSpyManagerAndCallProcessData("foo");
ModelPlotConfig modelConfig = mock(ModelPlotConfig.class);
List<DetectionRule> rules = Collections.singletonList(mock(DetectionRule.class));
List<JobUpdate.DetectorUpdate> detectorUpdates = Collections.singletonList(new JobUpdate.DetectorUpdate(2, null, rules));
@ -519,7 +498,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
manager.writeUpdateProcessMessage(jobTask, updateParams, e -> {});
ArgumentCaptor<UpdateProcessMessage> captor = ArgumentCaptor.forClass(UpdateProcessMessage.class);
verify(communicator).writeUpdateProcessMessage(captor.capture(), any());
verify(autodetectCommunicator).writeUpdateProcessMessage(captor.capture(), any());
UpdateProcessMessage updateProcessMessage = captor.getValue();
assertThat(updateProcessMessage.getModelPlotConfig(), equalTo(modelConfig));
@ -527,8 +506,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testJobHasActiveAutodetectProcess() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
assertFalse(manager.jobHasActiveAutodetectProcess(jobTask));
@ -545,8 +523,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testKillKillsAutodetectProcess() throws IOException {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
assertFalse(manager.jobHasActiveAutodetectProcess(jobTask));
@ -559,12 +536,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
manager.killAllProcessesOnThisNode();
verify(communicator).killProcess(false, false, true);
verify(autodetectCommunicator).killProcess(false, false, true);
}
public void testKillingAMissingJobFinishesTheTask() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
AutodetectProcessManager manager = createManager(communicator);
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
@ -574,14 +550,13 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testProcessData_GivenStateNotOpened() {
AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
BiConsumer<DataCounts, Exception> handler = (BiConsumer<DataCounts, Exception>) invocationOnMock.getArguments()[4];
handler.accept(new DataCounts("foo"), null);
return null;
}).when(communicator).writeToJob(any(), any(), any(), any(), any());
AutodetectProcessManager manager = createManager(communicator);
}).when(autodetectCommunicator).writeToJob(any(), any(), any(), any(), any());
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("foo");
@ -595,8 +570,6 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}
public void testCreate_notEnoughThreads() throws IOException {
Client client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ExecutorService executorService = mock(ExecutorService.class);
doThrow(new EsRejectedExecutionException("")).when(executorService).submit(any(Runnable.class));
@ -611,11 +584,9 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}).when(jobManager).getJob(eq("my_id"), any());
AutodetectProcess autodetectProcess = mock(AutodetectProcess.class);
AutodetectProcessFactory autodetectProcessFactory =
(j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
AutodetectProcessManager manager = new AutodetectProcessManager(environment, Settings.EMPTY,
client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory,
normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
AutodetectProcessManager manager = createSpyManager();
doCallRealMethod().when(manager).create(any(), any(), any(), any());
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn("my_id");
@ -675,86 +646,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
verifyNoMoreInteractions(auditor);
}
public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() {
final ThreadPool threadPool = new TestThreadPool("testAutodetectWorkerExecutorServiceDoesNotSwallowErrors");
try {
final AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
if (randomBoolean()) {
executor.submit(() -> {
throw new Error("future error");
});
} else {
executor.execute(() -> {
throw new Error("future error");
});
}
final Error e = expectThrows(Error.class, () -> executor.start());
assertThat(e.getMessage(), containsString("future error"));
} finally {
ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
}
}
public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() {
AutodetectProcessManager.AutodetectWorkerExecutorService executor =
new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
threadPool.generic().execute(() -> executor.start());
executor.shutdown();
expectThrows(EsRejectedExecutionException.class, () -> executor.execute(() -> {}));
}
public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown()
throws InterruptedException, ExecutionException {
AutodetectProcessManager.AutodetectWorkerExecutorService executor =
new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
CountDownLatch latch = new CountDownLatch(1);
Future<?> executorFinished = threadPool.generic().submit(() -> executor.start());
// run a task that will block while the others are queued up
executor.execute(() -> {
try {
latch.await();
} catch (InterruptedException e) {
}
});
AtomicBoolean runnableShouldNotBeCalled = new AtomicBoolean(false);
executor.execute(() -> runnableShouldNotBeCalled.set(true));
AtomicInteger onFailureCallCount = new AtomicInteger();
AtomicInteger doRunCallCount = new AtomicInteger();
for (int i=0; i<2; i++) {
executor.execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
onFailureCallCount.incrementAndGet();
}
@Override
protected void doRun() {
doRunCallCount.incrementAndGet();
}
});
}
// now shutdown
executor.shutdown();
latch.countDown();
executorFinished.get();
assertFalse(runnableShouldNotBeCalled.get());
// the AbstractRunnables should have had their callbacks called
assertEquals(2, onFailureCallCount.get());
assertEquals(0, doRunCallCount.get());
}
private AutodetectProcessManager createNonSpyManager(String jobId) {
Client client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ExecutorService executorService = mock(ExecutorService.class);
when(threadPool.executor(anyString())).thenReturn(executorService);
when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(mock(ThreadPool.Cancellable.class));
@ -766,11 +658,8 @@ public class AutodetectProcessManagerTests extends ESTestCase {
}).when(jobManager).getJob(eq(jobId), any());
AutodetectProcess autodetectProcess = mock(AutodetectProcess.class);
AutodetectProcessFactory autodetectProcessFactory =
(j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
return new AutodetectProcessManager(environment, Settings.EMPTY, client, threadPool, jobManager,
jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory,
normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
return createManager(Settings.EMPTY);
}
private AutodetectParams buildAutodetectParams() {
@ -783,27 +672,25 @@ public class AutodetectProcessManagerTests extends ESTestCase {
.build();
}
private AutodetectProcessManager createManager(AutodetectCommunicator communicator) {
Client client = mock(Client.class);
return createManager(communicator, client);
private AutodetectProcessManager createSpyManager() {
return createSpyManager(Settings.EMPTY);
}
private AutodetectProcessManager createManager(AutodetectCommunicator communicator, Client client) {
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService());
AutodetectProcessFactory autodetectProcessFactory = mock(AutodetectProcessFactory.class);
AutodetectProcessManager manager = new AutodetectProcessManager(environment, Settings.EMPTY,
client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister,
autodetectProcessFactory, normalizerFactory,
new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
private AutodetectProcessManager createSpyManager(Settings settings) {
AutodetectProcessManager manager = createManager(settings);
manager = spy(manager);
doReturn(communicator).when(manager).create(any(), any(), eq(buildAutodetectParams()), any());
doReturn(autodetectCommunicator).when(manager).create(any(), any(), eq(buildAutodetectParams()), any());
return manager;
}
private AutodetectProcessManager createManagerAndCallProcessData(AutodetectCommunicator communicator, String jobId) {
AutodetectProcessManager manager = createManager(communicator);
private AutodetectProcessManager createManager(Settings settings) {
return new AutodetectProcessManager(environment, settings,
client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister,
autodetectFactory, normalizerFactory,
new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
}
private AutodetectProcessManager createSpyManagerAndCallProcessData(String jobId) {
AutodetectProcessManager manager = createSpyManager();
JobTask jobTask = mock(JobTask.class);
when(jobTask.getJobId()).thenReturn(jobId);
manager.openJob(jobTask, clusterState, (e, b) -> {});

View File

@ -0,0 +1,100 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.job.process.autodetect;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.Matchers.containsString;
public class AutodetectWorkerExecutorServiceTests extends ESTestCase {
private ThreadPool threadPool = new TestThreadPool("AutodetectWorkerExecutorServiceTests");
@After
public void stopThreadPool() {
terminate(threadPool);
}
public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() {
AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
threadPool.generic().execute(() -> executor.start());
executor.shutdown();
expectThrows(EsRejectedExecutionException.class, () -> executor.execute(() -> {}));
}
public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown() throws Exception {
AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
CountDownLatch latch = new CountDownLatch(1);
Future<?> executorFinished = threadPool.generic().submit(() -> executor.start());
// run a task that will block while the others are queued up
executor.execute(() -> {
try {
latch.await();
} catch (InterruptedException e) {
}
});
AtomicBoolean runnableShouldNotBeCalled = new AtomicBoolean(false);
executor.execute(() -> runnableShouldNotBeCalled.set(true));
AtomicInteger onFailureCallCount = new AtomicInteger();
AtomicInteger doRunCallCount = new AtomicInteger();
for (int i=0; i<2; i++) {
executor.execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
onFailureCallCount.incrementAndGet();
}
@Override
protected void doRun() {
doRunCallCount.incrementAndGet();
}
});
}
// now shutdown
executor.shutdown();
latch.countDown();
executorFinished.get();
assertFalse(runnableShouldNotBeCalled.get());
// the AbstractRunnables should have had their callbacks called
assertEquals(2, onFailureCallCount.get());
assertEquals(0, doRunCallCount.get());
}
public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() {
AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
if (randomBoolean()) {
executor.submit(() -> {
throw new Error("future error");
});
} else {
executor.execute(() -> {
throw new Error("future error");
});
}
Error e = expectThrows(Error.class, () -> executor.start());
assertThat(e.getMessage(), containsString("future error"));
}
}