[ML] Wait for model loaded and cached in ModelLoadingServiceTests (#56014)

Fixes test by exposing the method ModelLoadingService::addModelLoadedListener() 
so that the test class can be notified when a model is loaded which happens in
a background thread
This commit is contained in:
David Kyle 2020-04-30 13:32:07 +01:00 committed by GitHub
parent 317d9fb88f
commit c204353249
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 2 deletions

View File

@ -324,7 +324,7 @@ public class ModelLoadingService implements ClusterStateListener {
// Populate loadingListeners key so we know that we are currently loading the model
for (String modelId : allReferencedModelKeys) {
loadingListeners.put(modelId, new ArrayDeque<>());
loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>());
}
} // synchronized (loadingListeners)
if (logger.isTraceEnabled()) {
@ -420,4 +420,25 @@ public class ModelLoadingService implements ClusterStateListener {
throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType);
}
}
/**
* Register a listener for notification when a model is loaded.
*
* This method is primarily intended for testing (hence package private)
* and shouldn't be required outside of testing.
*
* @param modelId Model Id
* @param modelLoadedListener To be notified
*/
void addModelLoadedListener(String modelId, ActionListener<Model> modelLoadedListener) {
synchronized (loadingListeners) {
loadingListeners.compute(modelId, (modelKey, listenerQueue) -> {
if (listenerQueue == null) {
return addFluently(new ArrayDeque<>(), modelLoadedListener);
} else {
return addFluently(listenerQueue, modelLoadedListener);
}
});
}
}
}

View File

@ -50,9 +50,13 @@ import org.mockito.ArgumentMatcher;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
import static org.hamcrest.Matchers.equalTo;
@ -145,7 +149,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any());
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/55251")
public void testMaxCachedLimitReached() throws Exception {
String model1 = "test-cached-limit-load-model-1";
String model2 = "test-cached-limit-load-model-2";
@ -164,6 +167,12 @@ public class ModelLoadingServiceTests extends ESTestCase {
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
"test-node");
// We want to be notified when the models are loaded which happens in a background thread
ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds));
for (String modelId : modelIds) {
modelLoadingService.addModelLoadedListener(modelId, loadedTracker.actionListener());
}
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
// Should have been loaded from the cluster change event but it is unknown in what order
@ -175,6 +184,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
});
// all models loaded put in the cache
assertBusy(() -> assertTrue(loadedTracker.allModelsLoaded()), 2, TimeUnit.SECONDS);
for(int i = 0; i < 10; i++) {
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
String model = modelIds[i%2];
@ -426,4 +438,28 @@ public class ModelLoadingServiceTests extends ESTestCase {
return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
}
}
private static class ModelLoadedTracker {
private final Set<String> expectedModelIds;
ModelLoadedTracker(Collection<String> expectedModelIds) {
this.expectedModelIds = new HashSet<>(expectedModelIds);
}
private synchronized boolean allModelsLoaded() {
return expectedModelIds.isEmpty();
}
private synchronized void onModelLoaded(Model model) {
expectedModelIds.remove(model.getModelId());
}
private void onFailure(Exception e) {
fail(e.getMessage());
}
ActionListener<Model> actionListener() {
return ActionListener.wrap(this::onModelLoaded, this::onFailure);
}
}
}