diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java
index afd670a1803..11835cdb8db 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java
@@ -299,7 +299,18 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
                 }
                 fullRefreshCompletionListeners.clear();
             }
-        }, onCompletion::onFailure);
+        },
+        e -> {
+            synchronized (fullRefreshCompletionListeners) {
+                assert fullRefreshCompletionListeners.isEmpty() == false;
+                for (ActionListener<Void> listener : fullRefreshCompletionListeners) {
+                    listener.onFailure(e);
+                }
+                // It's critical that we empty out the current listener list on
+                // error otherwise subsequent retries to refresh will be ignored
+                fullRefreshCompletionListeners.clear();
+            }
+        });
 
         // persistentTasks will be null if there's never been a persistent task created in this cluster
         if (persistentTasks == null) {
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java
index 429575902b0..07314948c04 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java
@@ -19,6 +19,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
@@ -32,11 +33,13 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.anyString;
 import static org.mockito.Mockito.doAnswer;
@@ -125,6 +128,66 @@ public class MlMemoryTrackerTests extends ESTestCase {
         }
     }
 
+    public void testRefreshAllFailure() {
+
+        Map<String, PersistentTasksCustomMetaData.PersistentTask<?>> tasks = new HashMap<>();
+
+        int numAnomalyDetectorJobTasks = randomIntBetween(2, 5);
+        for (int i = 1; i <= numAnomalyDetectorJobTasks; ++i) {
+            String jobId = "job" + i;
+            PersistentTasksCustomMetaData.PersistentTask<?> task = makeTestAnomalyDetectorTask(jobId);
+            tasks.put(task.getId(), task);
+        }
+
+        int numDataFrameAnalyticsTasks = randomIntBetween(2, 5);
+        for (int i = 1; i <= numDataFrameAnalyticsTasks; ++i) {
+            String id = "analytics" + i;
+            PersistentTasksCustomMetaData.PersistentTask<?> task = makeTestDataFrameAnalyticsTask(id);
+            tasks.put(task.getId(), task);
+        }
+
+        PersistentTasksCustomMetaData persistentTasks =
+            new PersistentTasksCustomMetaData(numAnomalyDetectorJobTasks + numDataFrameAnalyticsTasks, tasks);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            Consumer<Long> listener = (Consumer<Long>) invocation.getArguments()[3];
+            listener.accept(randomLongBetween(1000, 1000000));
+            return null;
+        }).when(jobResultsProvider).getEstablishedMemoryUsage(anyString(), any(), any(), any(Consumer.class), any());
+
+        // First run a refresh using a component that calls the onFailure method of the listener
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<List<DataFrameAnalyticsConfig>> listener =
+                (ActionListener<List<DataFrameAnalyticsConfig>>) invocation.getArguments()[2];
+            listener.onFailure(new IllegalArgumentException("computer says no"));
+            return null;
+        }).when(configProvider).getMultiple(anyString(), anyBoolean(), any(ActionListener.class));
+
+        AtomicBoolean gotErrorResponse = new AtomicBoolean(false);
+        memoryTracker.refresh(persistentTasks,
+            ActionListener.wrap(aVoid -> fail("Expected error response"), e -> gotErrorResponse.set(true)));
+        assertTrue(gotErrorResponse.get());
+
+        // Now run another refresh using a component that calls the onResponse method of the listener - this
+        // proves that the ML memory tracker has not been permanently blocked up by the previous failure
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<List<DataFrameAnalyticsConfig>> listener =
+                (ActionListener<List<DataFrameAnalyticsConfig>>) invocation.getArguments()[2];
+            listener.onResponse(Collections.emptyList());
+            return null;
+        }).when(configProvider).getMultiple(anyString(), anyBoolean(), any(ActionListener.class));
+
+        AtomicBoolean gotSuccessResponse = new AtomicBoolean(false);
+        memoryTracker.refresh(persistentTasks,
+            ActionListener.wrap(aVoid -> gotSuccessResponse.set(true), e -> fail("Expected success response")));
+        assertTrue(gotSuccessResponse.get());
+    }
+
     public void testRefreshOneAnomalyDetectorJob() {
 
         boolean isMaster = randomBoolean();