mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-09 06:25:07 +00:00
[ML] Wait for model loaded and cached in ModelLoadingServiceTests (#56014)
Fixes test by exposing the method ModelLoadingService::addModelLoadedListener() so that the test class can be notified when a model is loaded which happens in a background thread
This commit is contained in:
parent
317d9fb88f
commit
c204353249
@ -324,7 +324,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||||||
|
|
||||||
// Populate loadingListeners key so we know that we are currently loading the model
|
// Populate loadingListeners key so we know that we are currently loading the model
|
||||||
for (String modelId : allReferencedModelKeys) {
|
for (String modelId : allReferencedModelKeys) {
|
||||||
loadingListeners.put(modelId, new ArrayDeque<>());
|
loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>());
|
||||||
}
|
}
|
||||||
} // synchronized (loadingListeners)
|
} // synchronized (loadingListeners)
|
||||||
if (logger.isTraceEnabled()) {
|
if (logger.isTraceEnabled()) {
|
||||||
@ -420,4 +420,25 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||||||
throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType);
|
throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register a listener for notification when a model is loaded.
|
||||||
|
*
|
||||||
|
* This method is primarily intended for testing (hence package private)
|
||||||
|
* and shouldn't be required outside of testing.
|
||||||
|
*
|
||||||
|
* @param modelId Model Id
|
||||||
|
* @param modelLoadedListener To be notified
|
||||||
|
*/
|
||||||
|
void addModelLoadedListener(String modelId, ActionListener<Model> modelLoadedListener) {
|
||||||
|
synchronized (loadingListeners) {
|
||||||
|
loadingListeners.compute(modelId, (modelKey, listenerQueue) -> {
|
||||||
|
if (listenerQueue == null) {
|
||||||
|
return addFluently(new ArrayDeque<>(), modelLoadedListener);
|
||||||
|
} else {
|
||||||
|
return addFluently(listenerQueue, modelLoadedListener);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -50,9 +50,13 @@ import org.mockito.ArgumentMatcher;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.InetAddress;
|
import java.net.InetAddress;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
|
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
@ -145,7 +149,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||||||
verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any());
|
verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any());
|
||||||
}
|
}
|
||||||
|
|
||||||
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/55251")
|
|
||||||
public void testMaxCachedLimitReached() throws Exception {
|
public void testMaxCachedLimitReached() throws Exception {
|
||||||
String model1 = "test-cached-limit-load-model-1";
|
String model1 = "test-cached-limit-load-model-1";
|
||||||
String model2 = "test-cached-limit-load-model-2";
|
String model2 = "test-cached-limit-load-model-2";
|
||||||
@ -164,6 +167,12 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||||||
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
|
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
|
||||||
"test-node");
|
"test-node");
|
||||||
|
|
||||||
|
// We want to be notified when the models are loaded which happens in a background thread
|
||||||
|
ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds));
|
||||||
|
for (String modelId : modelIds) {
|
||||||
|
modelLoadingService.addModelLoadedListener(modelId, loadedTracker.actionListener());
|
||||||
|
}
|
||||||
|
|
||||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
||||||
|
|
||||||
// Should have been loaded from the cluster change event but it is unknown in what order
|
// Should have been loaded from the cluster change event but it is unknown in what order
|
||||||
@ -175,6 +184,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// all models loaded put in the cache
|
||||||
|
assertBusy(() -> assertTrue(loadedTracker.allModelsLoaded()), 2, TimeUnit.SECONDS);
|
||||||
|
|
||||||
for(int i = 0; i < 10; i++) {
|
for(int i = 0; i < 10; i++) {
|
||||||
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
|
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
|
||||||
String model = modelIds[i%2];
|
String model = modelIds[i%2];
|
||||||
@ -426,4 +438,28 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||||||
return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
|
return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class ModelLoadedTracker {
|
||||||
|
private final Set<String> expectedModelIds;
|
||||||
|
|
||||||
|
ModelLoadedTracker(Collection<String> expectedModelIds) {
|
||||||
|
this.expectedModelIds = new HashSet<>(expectedModelIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
private synchronized boolean allModelsLoaded() {
|
||||||
|
return expectedModelIds.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
private synchronized void onModelLoaded(Model model) {
|
||||||
|
expectedModelIds.remove(model.getModelId());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void onFailure(Exception e) {
|
||||||
|
fail(e.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
ActionListener<Model> actionListener() {
|
||||||
|
return ActionListener.wrap(this::onModelLoaded, this::onFailure);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user