From b87b1477045d5d10bb594464f9d3884f3dc3aaf1 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 11 Jun 2020 10:48:37 +0100 Subject: [PATCH] Add models for search to ModelLoadingService (#57592) (#57919) ModelLoadingService only caches models if they are referenced by an ingest pipeline. For models used in search we want to always cache the models and rely on TTL to evict them. Additionally when an ingest pipeline is deleted the model it references should not be evicted if it is used in search. --- .../TransportInternalInferModelAction.java | 4 +- .../inference/loadingservice/LocalModel.java | 2 - .../loadingservice/ModelLoadingService.java | 272 +++++++++++------- .../ModelLoadingServiceTests.java | 55 +++- 4 files changed, 219 insertions(+), 114 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index d229f4decbe..2f072ad2c4a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -75,13 +75,13 @@ public class TransportInternalInferModelAction extends HandledTransportAction { responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel())); if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) { - this.modelLoadingService.getModel(request.getModelId(), getModelListener); + this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener); } else { listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index d002f8deddf..83ce4a72e56 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -32,7 +32,6 @@ public class LocalModel implements Model { private final InferenceDefinition trainedModelDefinition; private final String modelId; - private final String nodeId; private final Set fieldNames; private final Map defaultFieldMap; private final InferenceStats.Accumulator statsAccumulator; @@ -50,7 +49,6 @@ public class LocalModel implements Model { TrainedModelStatsService trainedModelStatsService ) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; - this.nodeId = nodeId; this.fieldNames = new HashSet<>(input.getFieldNames()); this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId); this.trainedModelStatsService = trainedModelStatsService; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 436cdfd54f7..7cd00844899 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -43,6 +43,7 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -52,9 +53,11 @@ import java.util.Set; import java.util.concurrent.TimeUnit; /** - * This is a thread safe model loading service. + * This is a thread safe model loading service with LRU cache. + * Cache entries have a TTL before they are evicted. * - * It will cache local models that are referenced by processors in memory (as long as it is instantiated on an ingest node). + * In the case of a pipeline processor requesting the model if the processor is in simulate + * mode the model is not cached. All other uses will cache the model * * If more than one processor references the same model, that model will only be cached once. */ @@ -62,7 +65,7 @@ public class ModelLoadingService implements ClusterStateListener { /** * The maximum size of the local model cache here in the loading service - * + *

* Once the limit is reached, LRU models are evicted in favor of new models */ public static final Setting INFERENCE_MODEL_CACHE_SIZE = @@ -72,12 +75,11 @@ public class ModelLoadingService implements ClusterStateListener { /** * How long should a model stay in the cache since its last access - * + *

* If nothing references a model via getModel for this configured timeValue, it will be evicted. - * + *

* Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not * executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called. - * */ public static final Setting INFERENCE_MODEL_CACHE_TTL = Setting.timeSetting("xpack.ml.inference_model.time_to_live", @@ -85,9 +87,25 @@ public class ModelLoadingService implements ClusterStateListener { new TimeValue(1, TimeUnit.MILLISECONDS), Setting.Property.NodeScope); + // The feature requesting the model + public enum Consumer { + PIPELINE, SEARCH + } + + private static class ModelAndConsumer { + private LocalModel model; + private EnumSet consumers; + + private ModelAndConsumer(LocalModel model, Consumer consumer) { + this.model = model; + this.consumers = EnumSet.of(consumer); + } + } + + private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); private final TrainedModelStatsService modelStatsService; - private final Cache localModelCache; + private final Cache localModelCache; private final Set referencedModels = new HashSet<>(); private final Map>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; @@ -112,9 +130,9 @@ public class ModelLoadingService implements ClusterStateListener { this.auditor = auditor; this.modelStatsService = modelStatsService; this.shouldNotAudit = new HashSet<>(); - this.localModelCache = CacheBuilder.builder() + this.localModelCache = CacheBuilder.builder() .setMaximumWeight(this.maxCacheSize.getBytes()) - .weigher((id, localModel) -> localModel.ramBytesUsed()) + .weigher((id, modelAndConsumer) -> modelAndConsumer.model.ramBytesUsed()) // explicit declaration of the listener lambda necessary for Eclipse IDE 4.14 .removalListener(notification -> cacheEvictionListener(notification)) .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings)) @@ -124,101 +142,115 @@ public class ModelLoadingService implements ClusterStateListener { this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker"); } + boolean isModelCached(String modelId) { + return localModelCache.get(modelId) != null; + } + + /** + * Load the model for use by an ingest pipeline. The model will not be cached if there is no + * ingest pipeline referencing it i.e. it is used in simulate mode + * + * @param modelId the model to get + * @param modelActionListener the listener to alert when the model has been retrieved + */ + public void getModelForPipeline(String modelId, ActionListener modelActionListener) { + getModel(modelId, Consumer.PIPELINE, modelActionListener); + } + + /** + * Load the model for use by at search. Models requested by search are always cached. + * + * @param modelId the model to get + * @param modelActionListener the listener to alert when the model has been retrieved + */ + public void getModelForSearch(String modelId, ActionListener modelActionListener) { + getModel(modelId, Consumer.SEARCH, modelActionListener); + } + /** * Gets the model referenced by `modelId` and responds to the listener. - * + *

* This method first checks the local LRU cache for the model. If it is present, it is returned from cache. + *

+ * In the case of search if the model is not present one of the following occurs: + * - If it is currently being loaded the `modelActionListener` + * is added to the list of listeners to be alerted when the model is fully loaded. + * - Otherwise the model is loaded and cached * - * If it is not present, one of the following occurs: + * In the case of an ingest processor if it is not present, one of the following occurs: + *

+ * - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener` + * is added to the list of listeners to be alerted when the model is fully loaded. + * - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting + * model will attempt to be cached for future reference + * - If the models is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener. + * It is not cached. * - * - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener` - * is added to the list of listeners to be alerted when the model is fully loaded. - * - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting - * model will attempt to be cached for future reference - * - If the models is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener. - * It is not cached. + * The main difference being that models for search are always cached whereas pipeline models + * are only cached if they are referenced by an ingest pipeline * - * @param modelId the model to get + * @param modelId the model to get + * @param consumer which feature is requesting the model * @param modelActionListener the listener to alert when the model has been retrieved. */ - public void getModel(String modelId, ActionListener modelActionListener) { - LocalModel cachedModel = localModelCache.get(modelId); + private void getModel(String modelId, Consumer consumer, ActionListener modelActionListener) { + ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { - modelActionListener.onResponse(cachedModel); + cachedModel.consumers.add(consumer); + modelActionListener.onResponse(cachedModel.model); logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId)); return; } - if (loadModelIfNecessary(modelId, modelActionListener) == false) { - // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called - // by a simulated pipeline - logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); - provider.getTrainedModel(modelId, false, ActionListener.wrap( - trainedModelConfig -> { - // Verify we can pull the model into memory without causing OOM - trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); - provider.getTrainedModelForInference(modelId, ActionListener.wrap( - inferenceDefinition -> { - InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? - inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : - trainedModelConfig.getInferenceConfig(); - // Remove the bytes as we cannot control how long the caller will keep the model in memory - trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); - modelActionListener.onResponse(new LocalModel( - trainedModelConfig.getModelId(), - localNode, - inferenceDefinition, - trainedModelConfig.getInput(), - trainedModelConfig.getDefaultFieldMap(), - inferenceConfig, - modelStatsService)); - }, - // Failure getting the definition, remove the initial estimation value - e -> { - trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); - modelActionListener.onFailure(e); - } - )); - }, - modelActionListener::onFailure - )); - } else { + + if (loadModelIfNecessary(modelId, consumer, modelActionListener)) { logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId)); } } /** - * Returns true if the model is loaded and the listener has been given the cached model - * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded - * Returns false if the model is not loaded or actively being loaded + * If the model is cached it is returned directly to the listener + * else if the model is CURRENTLY being loaded the listener is added to be notified when it is loaded + * else the model load is initiated. + * + * @param modelId The model to get + * @param consumer The model consumer + * @param modelActionListener The listener + * @return If the model is cached or currently being loaded true is returned. If a new load is started + * false is returned to indicate a new load event */ - private boolean loadModelIfNecessary(String modelId, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionListener modelActionListener) { synchronized (loadingListeners) { - Model cachedModel = localModelCache.get(modelId); + ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { - modelActionListener.onResponse(cachedModel); + cachedModel.consumers.add(consumer); + modelActionListener.onResponse(cachedModel.model); return true; } - // It is referenced by a pipeline, but the cache does not contain it - if (referencedModels.contains(modelId)) { - // If the loaded model is referenced there but is not present, - // that means the previous load attempt failed or the model has been evicted - // Attempt to load and cache the model if necessary - if (loadingListeners.computeIfPresent( - modelId, - (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { - logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId)); - loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); - loadModel(modelId); - } + + // Add the listener to the queue if the model is loading + Queue> listeners = loadingListeners.computeIfPresent(modelId, + (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)); + + // The cachedModel entry is null, but there are listeners present, that means it is being loaded + if (listeners != null) { return true; } - // if the cachedModel entry is null, but there are listeners present, that means it is being loaded - return loadingListeners.computeIfPresent(modelId, - (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null; + + if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) { + // The model is requested by a pipeline but not referenced by any ingest pipelines. + // This means it is a simulate call and the model should not be cached + loadWithoutCaching(modelId, modelActionListener); + } else { + logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId)); + loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); + loadModel(modelId, consumer); + } + + return false; } // synchronized (loadingListeners) } - private void loadModel(String modelId) { + private void loadModel(String modelId, Consumer consumer) { provider.getTrainedModel(modelId, false, ActionListener.wrap( trainedModelConfig -> { trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); @@ -238,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener { return; } } - handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition); + handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition); }, failure -> { // We failed to get the definition, remove the initial estimation. @@ -255,7 +287,43 @@ public class ModelLoadingService implements ClusterStateListener { )); } + private void loadWithoutCaching(String modelId, ActionListener modelActionListener) { + // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called + // by a simulated pipeline + logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); + provider.getTrainedModel(modelId, false, ActionListener.wrap( + trainedModelConfig -> { + // Verify we can pull the model into memory without causing OOM + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); + provider.getTrainedModelForInference(modelId, ActionListener.wrap( + inferenceDefinition -> { + InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? + inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : + trainedModelConfig.getInferenceConfig(); + // Remove the bytes as we cannot control how long the caller will keep the model in memory + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + modelActionListener.onResponse(new LocalModel( + trainedModelConfig.getModelId(), + localNode, + inferenceDefinition, + trainedModelConfig.getInput(), + trainedModelConfig.getDefaultFieldMap(), + inferenceConfig, + modelStatsService)); + }, + // Failure getting the definition, remove the initial estimation value + e -> { + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + modelActionListener.onFailure(e); + } + )); + }, + modelActionListener::onFailure + )); + } + private void handleLoadSuccess(String modelId, + Consumer consumer, TrainedModelConfig trainedModelConfig, InferenceDefinition inferenceDefinition) { Queue> listeners; @@ -278,7 +346,7 @@ public class ModelLoadingService implements ClusterStateListener { trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed()); return; } - localModelCache.put(modelId, loadedModel); + localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer)); shouldNotAudit.remove(modelId); } // synchronized (loadingListeners) for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { @@ -301,7 +369,7 @@ public class ModelLoadingService implements ClusterStateListener { } } - private void cacheEvictionListener(RemovalNotification notification) { + private void cacheEvictionListener(RemovalNotification notification) { try { if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { MessageSupplier msg = () -> new ParameterizedMessage( @@ -310,15 +378,15 @@ public class ModelLoadingService implements ClusterStateListener { "If this is undesired, consider updating setting [{}] or [{}].", new ByteSizeValue(localModelCache.weight()).getStringRep(), maxCacheSize.getStringRep(), - new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(), + new ByteSizeValue(notification.getValue().model.ramBytesUsed()).getStringRep(), INFERENCE_MODEL_CACHE_SIZE.getKey(), INFERENCE_MODEL_CACHE_TTL.getKey()); auditIfNecessary(notification.getKey(), msg); } // If the model is no longer referenced, flush the stats to persist as soon as possible - notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false); + notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false); } finally { - trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed()); + trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().model.ramBytesUsed()); } } @@ -356,7 +424,13 @@ public class ModelLoadingService implements ClusterStateListener { removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys); // Remove all cached models that are not referenced by any processors - removedModels.forEach(localModelCache::invalidate); + // and are not used in search + removedModels.forEach(modelId -> { + ModelAndConsumer modelAndConsumer = localModelCache.get(modelId); + if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) { + localModelCache.invalidate(modelId); + } + }); // Remove the models that are no longer referenced referencedModels.removeAll(removedModels); shouldNotAudit.removeAll(removedModels); @@ -389,7 +463,7 @@ public class ModelLoadingService implements ClusterStateListener { } } removedModels.forEach(this::auditUnreferencedModel); - loadModels(allReferencedModelKeys); + loadModelsForPipeline(allReferencedModelKeys); } private void auditIfNecessary(String modelId, MessageSupplier msg) { @@ -402,7 +476,7 @@ public class ModelLoadingService implements ClusterStateListener { logger.info("[{}] {}", modelId, msg.get().getFormattedMessage()); } - private void loadModels(Set modelIds) { + private void loadModelsForPipeline(Set modelIds) { if (modelIds.isEmpty()) { return; } @@ -410,7 +484,7 @@ public class ModelLoadingService implements ClusterStateListener { threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { for (String modelId : modelIds) { auditNewReferencedModel(modelId); - this.loadModel(modelId); + this.loadModel(modelId, Consumer.PIPELINE); } }); } @@ -436,11 +510,11 @@ public class ModelLoadingService implements ClusterStateListener { ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { Object processors = pipelineConfiguration.getConfigAsMap().get("processors"); if (processors instanceof List) { - for(Object processor : (List)processors) { + for (Object processor : (List) processors) { if (processor instanceof Map) { - Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE); + Object processorConfig = ((Map) processor).get(InferenceProcessor.TYPE); if (processorConfig instanceof Map) { - Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); + Object modelId = ((Map) processorConfig).get(InferenceProcessor.MODEL_ID); if (modelId != null) { assert modelId instanceof String; allReferencedModelKeys.add(modelId.toString()); @@ -454,7 +528,7 @@ public class ModelLoadingService implements ClusterStateListener { } private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) { - switch(targetType) { + switch (targetType) { case REGRESSION: return RegressionConfig.EMPTY_PARAMS; case CLASSIFICATION: @@ -466,22 +540,22 @@ public class ModelLoadingService implements ClusterStateListener { /** * Register a listener for notification when a model is loaded. - * + *

* This method is primarily intended for testing (hence package private) * and shouldn't be required outside of testing. * - * @param modelId Model Id + * @param modelId Model Id * @param modelLoadedListener To be notified */ void addModelLoadedListener(String modelId, ActionListener modelLoadedListener) { synchronized (loadingListeners) { loadingListeners.compute(modelId, (modelKey, listenerQueue) -> { - if (listenerQueue == null) { - return addFluently(new ArrayDeque<>(), modelLoadedListener); - } else { - return addFluently(listenerQueue, modelLoadedListener); - } - }); + if (listenerQueue == null) { + return addFluently(new ArrayDeque<>(), modelLoadedListener); + } else { + return addFluently(listenerQueue, modelLoadedListener); + } + }); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 98f38bcd52c..5c2f18cc29e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -131,7 +131,7 @@ public class ModelLoadingServiceTests extends ESTestCase { for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, future); + modelLoadingService.getModelForPipeline(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -139,12 +139,16 @@ public class ModelLoadingServiceTests extends ESTestCase { verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any()); verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any()); + assertTrue(modelLoadingService.isModelCached(model1)); + assertTrue(modelLoadingService.isModelCached(model2)); + assertTrue(modelLoadingService.isModelCached(model3)); + // Test invalidate cache for model3 modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, future); + modelLoadingService.getModelForPipeline(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -196,7 +200,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) String model = modelIds[i%2]; PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, future); + modelLoadingService.getModelForPipeline(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -219,7 +223,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 3, should invalidate 1 and 2 for(int i = 0; i < 10; i++) { PlainActionFuture future3 = new PlainActionFuture<>(); - modelLoadingService.getModel(model3, future3); + modelLoadingService.getModelForPipeline(model3, future3); assertThat(future3.get(), is(not(nullValue()))); } verify(trainedModelProvider, times(2)).getTrainedModelForInference(eq(model3), any()); @@ -240,7 +244,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 1, should invalidate 3 for(int i = 0; i < 10; i++) { PlainActionFuture future1 = new PlainActionFuture<>(); - modelLoadingService.getModel(model1, future1); + modelLoadingService.getModelForPipeline(model1, future1); assertThat(future1.get(), is(not(nullValue()))); } verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any()); @@ -254,7 +258,7 @@ public class ModelLoadingServiceTests extends ESTestCase { // Load model 2 for(int i = 0; i < 10; i++) { PlainActionFuture future2 = new PlainActionFuture<>(); - modelLoadingService.getModel(model2, future2); + modelLoadingService.getModelForPipeline(model2, future2); assertThat(future2.get(), is(not(nullValue()))); } verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any()); @@ -265,7 +269,7 @@ public class ModelLoadingServiceTests extends ESTestCase { for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, future); + modelLoadingService.getModelForPipeline(model, future); assertThat(future.get(), is(not(nullValue()))); } @@ -292,10 +296,11 @@ public class ModelLoadingServiceTests extends ESTestCase { for(int i = 0; i < 10; i++) { PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model1, future); + modelLoadingService.getModelForPipeline(model1, future); assertThat(future.get(), is(not(nullValue()))); } + assertFalse(modelLoadingService.isModelCached(model1)); verify(trainedModelProvider, times(10)).getTrainedModelForInference(eq(model1), any()); verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } @@ -315,7 +320,7 @@ public class ModelLoadingServiceTests extends ESTestCase { modelLoadingService.clusterChanged(ingestChangedEvent(model)); PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, future); + modelLoadingService.getModelForPipeline(model, future); try { future.get(); @@ -323,6 +328,7 @@ public class ModelLoadingServiceTests extends ESTestCase { } catch (Exception ex) { assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); } + assertFalse(modelLoadingService.isModelCached(model)); verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model), any()); verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); @@ -342,13 +348,14 @@ public class ModelLoadingServiceTests extends ESTestCase { circuitBreaker); PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, future); + modelLoadingService.getModelForPipeline(model, future); try { future.get(); fail("Should not have succeeded"); } catch (Exception ex) { assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model))); } + assertFalse(modelLoadingService.isModelCached(model)); } public void testGetModelEagerly() throws Exception { @@ -366,11 +373,37 @@ public class ModelLoadingServiceTests extends ESTestCase { for(int i = 0; i < 3; i++) { PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModel(model, future); + modelLoadingService.getModelForPipeline(model, future); assertThat(future.get(), is(not(nullValue()))); } verify(trainedModelProvider, times(3)).getTrainedModelForInference(eq(model), any()); + assertFalse(modelLoadingService.isModelCached(model)); + verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); + } + + public void testGetModelForSearch() throws Exception { + String modelId = "test-get-model-for-search"; + withTrainedModel(modelId, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + for(int i = 0; i < 3; i++) { + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModelForSearch(modelId, future); + assertThat(future.get(), is(not(nullValue()))); + } + + assertTrue(modelLoadingService.isModelCached(modelId)); + + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(modelId), any()); verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); }