diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index e5376bccb17..6b8e9e44f59 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -435,10 +435,10 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu DatafeedManager datafeedManager = new DatafeedManager(threadPool, client, clusterService, datafeedJobBuilder, System::currentTimeMillis, auditor, autodetectProcessManager); this.datafeedManager.set(datafeedManager); - MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager, - autodetectProcessManager); MlMemoryTracker memoryTracker = new MlMemoryTracker(settings, clusterService, threadPool, jobManager, jobResultsProvider); this.memoryTracker.set(memoryTracker); + MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager, + autodetectProcessManager, memoryTracker); // This object's constructor attaches to the license state, so there's no need to retain another reference to it new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java index 8005912107a..06d9b749e1a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java @@ -9,6 +9,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.env.Environment; import org.elasticsearch.xpack.ml.datafeed.DatafeedManager; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.NativeControllerHolder; import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; @@ -20,16 +21,14 @@ public class MlLifeCycleService { private final Environment environment; private final DatafeedManager datafeedManager; private final AutodetectProcessManager autodetectProcessManager; - - public MlLifeCycleService(Environment environment, ClusterService clusterService) { - this(environment, clusterService, null, null); - } + private final MlMemoryTracker memoryTracker; public MlLifeCycleService(Environment environment, ClusterService clusterService, DatafeedManager datafeedManager, - AutodetectProcessManager autodetectProcessManager) { + AutodetectProcessManager autodetectProcessManager, MlMemoryTracker memoryTracker) { this.environment = environment; this.datafeedManager = datafeedManager; this.autodetectProcessManager = autodetectProcessManager; + this.memoryTracker = memoryTracker; clusterService.addLifecycleListener(new LifecycleListener() { @Override public void beforeStop() { @@ -59,5 +58,8 @@ public class MlLifeCycleService { } catch (IOException e) { // We're stopping anyway, so don't let this complicate the shutdown sequence } + if (memoryTracker != null) { + memoryTracker.stop(); + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java index 441317bcbe2..50d2515046a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java @@ -32,6 +32,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Phaser; import java.util.stream.Collectors; /** @@ -55,6 +56,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener { private final ClusterService clusterService; private final JobManager jobManager; private final JobResultsProvider jobResultsProvider; + private final Phaser stopPhaser; private volatile boolean isMaster; private volatile Instant lastUpdateTime; private volatile Duration reassignmentRecheckInterval; @@ -65,6 +67,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener { this.clusterService = clusterService; this.jobManager = jobManager; this.jobResultsProvider = jobResultsProvider; + this.stopPhaser = new Phaser(1); setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings)); clusterService.addLocalNodeMasterListener(this); clusterService.getClusterSettings().addSettingsUpdateConsumer( @@ -89,6 +92,23 @@ public class MlMemoryTracker implements LocalNodeMasterListener { lastUpdateTime = null; } + /** + * Wait for all outstanding searches to complete. + * After returning, no new searches can be started. + */ + public void stop() { + logger.trace("ML memory tracker stop called"); + // We never terminate the phaser + assert stopPhaser.isTerminated() == false; + // If there are no registered parties or no unarrived parties then there is a flaw + // in the register/arrive/unregister logic in another method that uses the phaser + assert stopPhaser.getRegisteredParties() > 0; + assert stopPhaser.getUnarrivedParties() > 0; + stopPhaser.arriveAndAwaitAdvance(); + assert stopPhaser.getPhase() > 0; + logger.debug("ML memory tracker stopped"); + } + @Override public String executorName() { return MachineLearning.UTILITY_THREAD_POOL_NAME; @@ -146,13 +166,13 @@ public class MlMemoryTracker implements LocalNodeMasterListener { try { ActionListener listener = ActionListener.wrap( aVoid -> logger.trace("Job memory requirement refresh request completed successfully"), - e -> logger.error("Failed to refresh job memory requirements", e) + e -> logger.warn("Failed to refresh job memory requirements", e) ); threadPool.executor(executorName()).execute( () -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), listener)); return true; } catch (EsRejectedExecutionException e) { - logger.debug("Couldn't schedule ML memory update - node might be shutting down", e); + logger.warn("Couldn't schedule ML memory update - node might be shutting down", e); } } @@ -246,25 +266,43 @@ public class MlMemoryTracker implements LocalNodeMasterListener { return; } + // The phaser prevents searches being started after the memory tracker's stop() method has returned + if (stopPhaser.register() != 0) { + // Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction + stopPhaser.arriveAndDeregister(); + listener.onFailure(new EsRejectedExecutionException("Couldn't run ML memory update - node is shutting down")); + return; + } + ActionListener phaserListener = ActionListener.wrap( + r -> { + stopPhaser.arriveAndDeregister(); + listener.onResponse(r); + }, + e -> { + stopPhaser.arriveAndDeregister(); + listener.onFailure(e); + } + ); + try { jobResultsProvider.getEstablishedMemoryUsage(jobId, null, null, establishedModelMemoryBytes -> { if (establishedModelMemoryBytes <= 0L) { - setJobMemoryToLimit(jobId, listener); + setJobMemoryToLimit(jobId, phaserListener); } else { Long memoryRequirementBytes = establishedModelMemoryBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes(); memoryRequirementByJob.put(jobId, memoryRequirementBytes); - listener.onResponse(memoryRequirementBytes); + phaserListener.onResponse(memoryRequirementBytes); } }, e -> { logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); - setJobMemoryToLimit(jobId, listener); + setJobMemoryToLimit(jobId, phaserListener); } ); } catch (Exception e) { logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); - setJobMemoryToLimit(jobId, listener); + setJobMemoryToLimit(jobId, phaserListener); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java index 3e54994ac04..1dd2ba923ef 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.persistent.PersistentTasksClusterService; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.test.ESTestCase; @@ -29,6 +30,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.anyString; @@ -157,6 +159,19 @@ public class MlMemoryTrackerTests extends ESTestCase { assertNull(memoryTracker.getJobMemoryRequirement(jobId)); } + public void testStop() { + + memoryTracker.onMaster(); + memoryTracker.stop(); + + AtomicReference exception = new AtomicReference<>(); + memoryTracker.refreshJobMemory("job", ActionListener.wrap(ESTestCase::assertNull, exception::set)); + + assertNotNull(exception.get()); + assertThat(exception.get(), instanceOf(EsRejectedExecutionException.class)); + assertEquals("Couldn't run ML memory update - node is shutting down", exception.get().getMessage()); + } + private PersistentTasksCustomMetaData.PersistentTask makeTestTask(String jobId) { return new PersistentTasksCustomMetaData.PersistentTask<>("job-" + jobId, MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams(jobId), 0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT);