From c204353249c4795b183eec065353f13eadc8d458 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 30 Apr 2020 13:32:07 +0100 Subject: [PATCH] [ML] Wait for model loaded and cached in ModelLoadingServiceTests (#56014) Fixes test by exposing the method ModelLoadingService::addModelLoadedListener() so that the test class can be notified when a model is loaded which happens in a background thread --- .../loadingservice/ModelLoadingService.java | 23 ++++++++++- .../ModelLoadingServiceTests.java | 38 ++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 842e48ad1e7..5053972e50b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -324,7 +324,7 @@ public class ModelLoadingService implements ClusterStateListener { // Populate loadingListeners key so we know that we are currently loading the model for (String modelId : allReferencedModelKeys) { - loadingListeners.put(modelId, new ArrayDeque<>()); + loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>()); } } // synchronized (loadingListeners) if (logger.isTraceEnabled()) { @@ -420,4 +420,25 @@ public class ModelLoadingService implements ClusterStateListener { throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType); } } + + /** + * Register a listener for notification when a model is loaded. + * + * This method is primarily intended for testing (hence package private) + * and shouldn't be required outside of testing. + * + * @param modelId Model Id + * @param modelLoadedListener To be notified + */ + void addModelLoadedListener(String modelId, ActionListener modelLoadedListener) { + synchronized (loadingListeners) { + loadingListeners.compute(modelId, (modelKey, listenerQueue) -> { + if (listenerQueue == null) { + return addFluently(new ArrayDeque<>(), modelLoadedListener); + } else { + return addFluently(listenerQueue, modelLoadedListener); + } + }); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 2db900c1bc7..01395e914cd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -50,9 +50,13 @@ import org.mockito.ArgumentMatcher; import java.io.IOException; import java.net.InetAddress; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; @@ -145,7 +149,6 @@ public class ModelLoadingServiceTests extends ESTestCase { verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any()); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/55251") public void testMaxCachedLimitReached() throws Exception { String model1 = "test-cached-limit-load-model-1"; String model2 = "test-cached-limit-load-model-2"; @@ -164,6 +167,12 @@ public class ModelLoadingServiceTests extends ESTestCase { Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(), "test-node"); + // We want to be notified when the models are loaded which happens in a background thread + ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds)); + for (String modelId : modelIds) { + modelLoadingService.addModelLoadedListener(modelId, loadedTracker.actionListener()); + } + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); // Should have been loaded from the cluster change event but it is unknown in what order @@ -175,6 +184,9 @@ public class ModelLoadingServiceTests extends ESTestCase { verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); }); + // all models loaded put in the cache + assertBusy(() -> assertTrue(loadedTracker.allModelsLoaded()), 2, TimeUnit.SECONDS); + for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) String model = modelIds[i%2]; @@ -426,4 +438,28 @@ public class ModelLoadingServiceTests extends ESTestCase { return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); } } + + private static class ModelLoadedTracker { + private final Set expectedModelIds; + + ModelLoadedTracker(Collection expectedModelIds) { + this.expectedModelIds = new HashSet<>(expectedModelIds); + } + + private synchronized boolean allModelsLoaded() { + return expectedModelIds.isEmpty(); + } + + private synchronized void onModelLoaded(Model model) { + expectedModelIds.remove(model.getModelId()); + } + + private void onFailure(Exception e) { + fail(e.getMessage()); + } + + ActionListener actionListener() { + return ActionListener.wrap(this::onModelLoaded, this::onFailure); + } + } }