[ML] Wait for model loaded and cached in ModelLoadingServiceTests (#56014)
Fixes test by exposing the method ModelLoadingService::addModelLoadedListener() so that the test class can be notified when a model is loaded which happens in a background thread
This commit is contained in:
parent
317d9fb88f
commit
c204353249
|
@ -324,7 +324,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
|
||||
// Populate loadingListeners key so we know that we are currently loading the model
|
||||
for (String modelId : allReferencedModelKeys) {
|
||||
loadingListeners.put(modelId, new ArrayDeque<>());
|
||||
loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>());
|
||||
}
|
||||
} // synchronized (loadingListeners)
|
||||
if (logger.isTraceEnabled()) {
|
||||
|
@ -420,4 +420,25 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a listener for notification when a model is loaded.
|
||||
*
|
||||
* This method is primarily intended for testing (hence package private)
|
||||
* and shouldn't be required outside of testing.
|
||||
*
|
||||
* @param modelId Model Id
|
||||
* @param modelLoadedListener To be notified
|
||||
*/
|
||||
void addModelLoadedListener(String modelId, ActionListener<Model> modelLoadedListener) {
|
||||
synchronized (loadingListeners) {
|
||||
loadingListeners.compute(modelId, (modelKey, listenerQueue) -> {
|
||||
if (listenerQueue == null) {
|
||||
return addFluently(new ArrayDeque<>(), modelLoadedListener);
|
||||
} else {
|
||||
return addFluently(listenerQueue, modelLoadedListener);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,9 +50,13 @@ import org.mockito.ArgumentMatcher;
|
|||
import java.io.IOException;
|
||||
import java.net.InetAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
@ -145,7 +149,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any());
|
||||
}
|
||||
|
||||
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/55251")
|
||||
public void testMaxCachedLimitReached() throws Exception {
|
||||
String model1 = "test-cached-limit-load-model-1";
|
||||
String model2 = "test-cached-limit-load-model-2";
|
||||
|
@ -164,6 +167,12 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
|
||||
"test-node");
|
||||
|
||||
// We want to be notified when the models are loaded which happens in a background thread
|
||||
ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds));
|
||||
for (String modelId : modelIds) {
|
||||
modelLoadingService.addModelLoadedListener(modelId, loadedTracker.actionListener());
|
||||
}
|
||||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
||||
|
||||
// Should have been loaded from the cluster change event but it is unknown in what order
|
||||
|
@ -175,6 +184,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
||||
});
|
||||
|
||||
// all models loaded put in the cache
|
||||
assertBusy(() -> assertTrue(loadedTracker.allModelsLoaded()), 2, TimeUnit.SECONDS);
|
||||
|
||||
for(int i = 0; i < 10; i++) {
|
||||
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
|
||||
String model = modelIds[i%2];
|
||||
|
@ -426,4 +438,28 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
|
||||
}
|
||||
}
|
||||
|
||||
private static class ModelLoadedTracker {
|
||||
private final Set<String> expectedModelIds;
|
||||
|
||||
ModelLoadedTracker(Collection<String> expectedModelIds) {
|
||||
this.expectedModelIds = new HashSet<>(expectedModelIds);
|
||||
}
|
||||
|
||||
private synchronized boolean allModelsLoaded() {
|
||||
return expectedModelIds.isEmpty();
|
||||
}
|
||||
|
||||
private synchronized void onModelLoaded(Model model) {
|
||||
expectedModelIds.remove(model.getModelId());
|
||||
}
|
||||
|
||||
private void onFailure(Exception e) {
|
||||
fail(e.getMessage());
|
||||
}
|
||||
|
||||
ActionListener<Model> actionListener() {
|
||||
return ActionListener.wrap(this::onModelLoaded, this::onFailure);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue