diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
index d229f4decbe..2f072ad2c4a 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
@@ -75,13 +75,13 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Request, Response> {
 
         responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
         if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
-            this.modelLoadingService.getModel(request.getModelId(), getModelListener);
+            this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
         } else {
             listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
         }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
index d002f8deddf..83ce4a72e56 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
@@ -32,7 +32,6 @@ public class LocalModel implements Model {
 
     private final InferenceDefinition trainedModelDefinition;
     private final String modelId;
-    private final String nodeId;
    private final Set<String> fieldNames;
     private final Map<String, String> defaultFieldMap;
     private final InferenceStats.Accumulator statsAccumulator;
@@ -50,7 +49,6 @@ public class LocalModel implements Model {
                TrainedModelStatsService trainedModelStatsService) {
         this.trainedModelDefinition = trainedModelDefinition;
         this.modelId = modelId;
-        this.nodeId = nodeId;
         this.fieldNames = new HashSet<>(input.getFieldNames());
         this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
         this.trainedModelStatsService = trainedModelStatsService;
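Before the ModelLoadingService diff itself, here is the core idea of this change in one self-contained sketch: each cache entry remembers every feature (consumer) that has requested it, and pipeline-driven invalidation spares entries that search still uses. Plain HashMap/EnumSet stand-ins replace the Elasticsearch Cache class, the simulate/no-cache path is omitted, and ModelCacheSketch and its members are illustrative names only, not part of the change:

import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;

class ModelCacheSketch {
    enum Consumer { PIPELINE, SEARCH }

    static final class Entry {
        final String model; // stand-in for LocalModel
        final EnumSet<Consumer> consumers;
        Entry(String model, Consumer first) {
            this.model = model;
            this.consumers = EnumSet.of(first);
        }
    }

    private final Map<String, Entry> cache = new HashMap<>();

    String get(String modelId, Consumer consumer) {
        Entry e = cache.computeIfAbsent(modelId, id -> new Entry("loaded-" + id, consumer));
        e.consumers.add(consumer); // remember every feature that touched this entry
        return e.model;
    }

    // Mirror of the clusterChanged logic below: drop models no pipeline
    // references, unless search has also used them.
    void invalidateUnreferenced(String modelId) {
        Entry e = cache.get(modelId);
        if (e != null && e.consumers.contains(Consumer.SEARCH) == false) {
            cache.remove(modelId);
        }
    }
}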
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
index 436cdfd54f7..7cd00844899 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
@@ -43,6 +43,7 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -52,9 +53,11 @@ import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
 /**
- * This is a thread safe model loading service.
+ * This is a thread safe model loading service with an LRU cache.
+ * Cache entries have a TTL before they are evicted.
  *
- * It will cache local models that are referenced by processors in memory (as long as it is instantiated on an ingest node).
+ * If a pipeline processor requests the model while the processor is in simulate
+ * mode, the model is not cached. All other uses will cache the model.
  *
  * If more than one processor references the same model, that model will only be cached once.
  */
@@ -62,7 +65,7 @@ public class ModelLoadingService implements ClusterStateListener {
 
     /**
      * The maximum size of the local model cache here in the loading service
-     *
+     * <p>
      * Once the limit is reached, LRU models are evicted in favor of new models
     */
     public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE =
@@ -72,12 +75,11 @@ public class ModelLoadingService implements ClusterStateListener {
 
     /**
      * How long should a model stay in the cache since its last access
-     *
+     * <p>
      * If nothing references a model via getModel for this configured timeValue, it will be evicted.
-     *
+     * <p>
      * Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not
      * executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called.
-     *
      */
     public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL =
         Setting.timeSetting("xpack.ml.inference_model.time_to_live",
@@ -85,9 +87,25 @@ public class ModelLoadingService implements ClusterStateListener {
             new TimeValue(1, TimeUnit.MILLISECONDS),
             Setting.Property.NodeScope);
 
+    // The feature requesting the model
+    public enum Consumer {
+        PIPELINE, SEARCH
+    }
+
+    private static class ModelAndConsumer {
+        private LocalModel model;
+        private EnumSet<Consumer> consumers;
+
+        private ModelAndConsumer(LocalModel model, Consumer consumer) {
+            this.model = model;
+            this.consumers = EnumSet.of(consumer);
+        }
+    }
+
     private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
     private final TrainedModelStatsService modelStatsService;
-    private final Cache<String, LocalModel> localModelCache;
+    private final Cache<String, ModelAndConsumer> localModelCache;
     private final Set<String> referencedModels = new HashSet<>();
     private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
     private final TrainedModelProvider provider;
@@ -112,9 +130,9 @@ public class ModelLoadingService implements ClusterStateListener {
         this.auditor = auditor;
         this.modelStatsService = modelStatsService;
         this.shouldNotAudit = new HashSet<>();
-        this.localModelCache = CacheBuilder.<String, LocalModel>builder()
+        this.localModelCache = CacheBuilder.<String, ModelAndConsumer>builder()
             .setMaximumWeight(this.maxCacheSize.getBytes())
-            .weigher((id, localModel) -> localModel.ramBytesUsed())
+            .weigher((id, modelAndConsumer) -> modelAndConsumer.model.ramBytesUsed())
             // explicit declaration of the listener lambda necessary for Eclipse IDE 4.14
             .removalListener(notification -> cacheEvictionListener(notification))
             .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
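The constructor hunk above wires the cache with a RAM-derived weight per entry and an access-based TTL. A minimal stand-alone sketch of that configuration, reusing the org.elasticsearch.common.cache builder calls visible in the hunk (7.x import paths); the String value type, the numeric weights, and the five minute TTL are placeholder assumptions for illustration:

import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.unit.TimeValue;

class CacheConfigSketch {
    Cache<String, String> build() {
        return CacheBuilder.<String, String>builder()
            .setMaximumWeight(10_000)                    // total budget; bytes in the real service
            .weigher((id, value) -> 1_024)               // per-entry cost; ramBytesUsed() in the change
            .removalListener(n -> System.out.println("evicted " + n.getKey()))
            .setExpireAfterAccess(TimeValue.timeValueMinutes(5)) // idle entries age out
            .build();
    }
}

An entry's weight is charged against the maximum; once the budget is exceeded, least-recently-used entries are evicted through the removal listener, which is where the service returns circuit-breaker bytes and flushes stats.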
@@ -124,101 +142,115 @@ public class ModelLoadingService implements ClusterStateListener {
         this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
     }
 
+    boolean isModelCached(String modelId) {
+        return localModelCache.get(modelId) != null;
+    }
+
+    /**
+     * Load the model for use by an ingest pipeline. The model will not be cached if there is no
+     * ingest pipeline referencing it, i.e. it is used in simulate mode.
+     *
+     * @param modelId the model to get
+     * @param modelActionListener the listener to alert when the model has been retrieved
+     */
+    public void getModelForPipeline(String modelId, ActionListener<Model> modelActionListener) {
+        getModel(modelId, Consumer.PIPELINE, modelActionListener);
+    }
+
+    /**
+     * Load the model for use at search. Models requested by search are always cached.
+     *
+     * @param modelId the model to get
+     * @param modelActionListener the listener to alert when the model has been retrieved
+     */
+    public void getModelForSearch(String modelId, ActionListener<Model> modelActionListener) {
+        getModel(modelId, Consumer.SEARCH, modelActionListener);
+    }
+
     /**
      * Gets the model referenced by `modelId` and responds to the listener.
-     *
+     * <p>
      * This method first checks the local LRU cache for the model. If it is present, it is returned from cache.
+     * <p>
+     * In the case of search, if the model is not present one of the following occurs:
+     * - If it is currently being loaded, the `modelActionListener`
+     *   is added to the list of listeners to be alerted when the model is fully loaded.
+     * - Otherwise the model is loaded and cached.
      *
-     * If it is not present, one of the following occurs:
+     * In the case of an ingest processor, if it is not present one of the following occurs:
+     * <p>
+     * - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
+     *   is added to the list of listeners to be alerted when the model is fully loaded.
+     * - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting
+     *   model will be cached for future reference.
+     * - If the model is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
+     *   It is not cached.
     *
-     * - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
-     *   is added to the list of listeners to be alerted when the model is fully loaded.
-     * - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting
-     *   model will attempt to be cached for future reference
-     * - If the models is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
-     *   It is not cached.
+     * The main difference is that models for search are always cached, whereas pipeline models
+     * are only cached while they are referenced by an ingest pipeline.
      *
-     * @param modelId the model to get
+     * @param modelId             the model to get
+     * @param consumer            which feature is requesting the model
      * @param modelActionListener the listener to alert when the model has been retrieved.
      */
-    public void getModel(String modelId, ActionListener<Model> modelActionListener) {
-        LocalModel cachedModel = localModelCache.get(modelId);
+    private void getModel(String modelId, Consumer consumer, ActionListener<Model> modelActionListener) {
+        ModelAndConsumer cachedModel = localModelCache.get(modelId);
         if (cachedModel != null) {
-            modelActionListener.onResponse(cachedModel);
+            cachedModel.consumers.add(consumer);
+            modelActionListener.onResponse(cachedModel.model);
             logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
             return;
         }
-        if (loadModelIfNecessary(modelId, modelActionListener) == false) {
-            // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
-            // by a simulated pipeline
-            logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
-            provider.getTrainedModel(modelId, false, ActionListener.wrap(
-                trainedModelConfig -> {
-                    // Verify we can pull the model into memory without causing OOM
-                    trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
-                    provider.getTrainedModelForInference(modelId, ActionListener.wrap(
-                        inferenceDefinition -> {
-                            InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
-                                inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
-                                trainedModelConfig.getInferenceConfig();
-                            // Remove the bytes as we cannot control how long the caller will keep the model in memory
-                            trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
-                            modelActionListener.onResponse(new LocalModel(
-                                trainedModelConfig.getModelId(),
-                                localNode,
-                                inferenceDefinition,
-                                trainedModelConfig.getInput(),
-                                trainedModelConfig.getDefaultFieldMap(),
-                                inferenceConfig,
-                                modelStatsService));
-                        },
-                        // Failure getting the definition, remove the initial estimation value
-                        e -> {
-                            trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
-                            modelActionListener.onFailure(e);
-                        }
-                    ));
-                },
-                modelActionListener::onFailure
-            ));
-        } else {
+
+        if (loadModelIfNecessary(modelId, consumer, modelActionListener)) {
             logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
         }
     }
 
     /**
-     * Returns true if the model is loaded and the listener has been given the cached model
-     * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
-     * Returns false if the model is not loaded or actively being loaded
+     * If the model is cached, it is returned directly to the listener;
+     * else if the model is CURRENTLY being loaded, the listener is added to be notified when it is loaded;
+     * else the model load is initiated.
+     *
+     * @param modelId The model to get
+     * @param consumer The model consumer
+     * @param modelActionListener The listener
+     * @return true if the model is cached or currently being loaded,
+     *         false if a new load was started
      */
-    private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
+    private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionListener<Model> modelActionListener) {
         synchronized (loadingListeners) {
-            Model cachedModel = localModelCache.get(modelId);
+            ModelAndConsumer cachedModel = localModelCache.get(modelId);
             if (cachedModel != null) {
-                modelActionListener.onResponse(cachedModel);
+                cachedModel.consumers.add(consumer);
+                modelActionListener.onResponse(cachedModel.model);
                 return true;
             }
-            // It is referenced by a pipeline, but the cache does not contain it
-            if (referencedModels.contains(modelId)) {
-                // If the loaded model is referenced there but is not present,
-                // that means the previous load attempt failed or the model has been evicted
-                // Attempt to load and cache the model if necessary
-                if (loadingListeners.computeIfPresent(
-                    modelId,
-                    (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
-                    logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
-                    loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
-                    loadModel(modelId);
-                }
+
+            // Add the listener to the queue if the model is loading
+            Queue<ActionListener<Model>> listeners = loadingListeners.computeIfPresent(modelId,
+                (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener));
+
+            // The cachedModel entry is null but listeners are present: the model is currently being loaded
+            if (listeners != null) {
                 return true;
             }
-            // if the cachedModel entry is null, but there are listeners present, that means it is being loaded
-            return loadingListeners.computeIfPresent(modelId,
-                (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null;
+
+            if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
+                // The model is requested by a pipeline but not referenced by any ingest pipelines.
+                // This means it is a simulate call and the model should not be cached
+                loadWithoutCaching(modelId, modelActionListener);
+            } else {
+                logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
+                loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
+                loadModel(modelId, consumer);
+            }
+
+            return false;
         } // synchronized (loadingListeners)
     }
 
-    private void loadModel(String modelId) {
+    private void loadModel(String modelId, Consumer consumer) {
         provider.getTrainedModel(modelId, false, ActionListener.wrap(
             trainedModelConfig -> {
                 trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
@@ -238,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener {
                             return;
                         }
                     }
-                    handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition);
+                    handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
                 },
                 failure -> {
                     // We failed to get the definition, remove the initial estimation.
@@ -255,7 +287,43 @@ public class ModelLoadingService implements ClusterStateListener {
         ));
     }
 
+    private void loadWithoutCaching(String modelId, ActionListener<Model> modelActionListener) {
+        // If the model is not loaded and we did not kick off a new loading attempt, we may be getting called
+        // by a simulated pipeline
+        logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
+        provider.getTrainedModel(modelId, false, ActionListener.wrap(
+            trainedModelConfig -> {
+                // Verify we can pull the model into memory without causing OOM
+                trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
+                provider.getTrainedModelForInference(modelId, ActionListener.wrap(
+                    inferenceDefinition -> {
+                        InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
+                            inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
+                            trainedModelConfig.getInferenceConfig();
+                        // Remove the bytes as we cannot control how long the caller will keep the model in memory
+                        trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
+                        modelActionListener.onResponse(new LocalModel(
+                            trainedModelConfig.getModelId(),
+                            localNode,
+                            inferenceDefinition,
+                            trainedModelConfig.getInput(),
+                            trainedModelConfig.getDefaultFieldMap(),
+                            inferenceConfig,
+                            modelStatsService));
+                    },
+                    // Failure getting the definition, remove the initial estimation value
+                    e -> {
+                        trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
+                        modelActionListener.onFailure(e);
+                    }
+                ));
+            },
+            modelActionListener::onFailure
+        ));
+    }
+
     private void handleLoadSuccess(String modelId,
+                                   Consumer consumer,
                                    TrainedModelConfig trainedModelConfig,
                                    InferenceDefinition inferenceDefinition) {
         Queue<ActionListener<Model>> listeners;
@@ -278,7 +346,7 @@ public class ModelLoadingService implements ClusterStateListener {
                 trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed());
                 return;
             }
-            localModelCache.put(modelId, loadedModel);
+            localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
             shouldNotAudit.remove(modelId);
         } // synchronized (loadingListeners)
         for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
@@ -301,7 +369,7 @@ public class ModelLoadingService implements ClusterStateListener {
         }
     }
 
-    private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
+    private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
         try {
             if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
                 MessageSupplier msg = () -> new ParameterizedMessage(
@@ -310,15 +378,15 @@ public class ModelLoadingService implements ClusterStateListener {
                     "If this is undesired, consider updating setting [{}] or [{}].",
                     new ByteSizeValue(localModelCache.weight()).getStringRep(),
                     maxCacheSize.getStringRep(),
-                    new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
+                    new ByteSizeValue(notification.getValue().model.ramBytesUsed()).getStringRep(),
                     INFERENCE_MODEL_CACHE_SIZE.getKey(),
                     INFERENCE_MODEL_CACHE_TTL.getKey());
                 auditIfNecessary(notification.getKey(), msg);
             }
             // If the model is no longer referenced, flush the stats to persist as soon as possible
-            notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false);
+            notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
         } finally {
-            trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed());
+            trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().model.ramBytesUsed());
         }
     }
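One subtlety worth spelling out: circuit-breaker accounting differs between the two load paths. loadModel keeps the reserved bytes for as long as the model sits in the cache and only returns them in cacheEvictionListener above, whereas loadWithoutCaching returns the reservation as soon as the model is handed to the caller. A compressed sketch of the uncached pattern, using the same CircuitBreaker methods that appear in the diff; loadOnce and the string "definition" are illustrative stand-ins:

import org.elasticsearch.common.breaker.CircuitBreaker;

final class BreakerAccountingSketch {
    // Uncached (simulate) path: reserve up front, fail fast if the node lacks
    // headroom, then release once the caller owns the model's memory.
    static String loadOnce(CircuitBreaker breaker, long estimatedBytes, String modelId) {
        breaker.addEstimateBytesAndMaybeBreak(estimatedBytes, modelId); // may throw CircuitBreakingException
        try {
            return "definition-of-" + modelId; // stand-in for fetching the definition
        } finally {
            breaker.addWithoutBreaking(-estimatedBytes); // hand accounting over to the caller
        }
    }
    // Cached path (loadModel): no release here; the reservation lives as long as
    // the cache entry and is returned by the removal listener on eviction.
}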
@@ -356,7 +424,13 @@ public class ModelLoadingService implements ClusterStateListener {
                 removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
 
                 // Remove all cached models that are not referenced by any processors
-                removedModels.forEach(localModelCache::invalidate);
+                // and are not used in search
+                removedModels.forEach(modelId -> {
+                    ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
+                    if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
+                        localModelCache.invalidate(modelId);
+                    }
+                });
                 // Remove the models that are no longer referenced
                 referencedModels.removeAll(removedModels);
                 shouldNotAudit.removeAll(removedModels);
@@ -389,7 +463,7 @@ public class ModelLoadingService implements ClusterStateListener {
             }
         }
         removedModels.forEach(this::auditUnreferencedModel);
-        loadModels(allReferencedModelKeys);
+        loadModelsForPipeline(allReferencedModelKeys);
     }
 
     private void auditIfNecessary(String modelId, MessageSupplier msg) {
@@ -402,7 +476,7 @@ public class ModelLoadingService implements ClusterStateListener {
         logger.info("[{}] {}", modelId, msg.get().getFormattedMessage());
     }
 
-    private void loadModels(Set<String> modelIds) {
+    private void loadModelsForPipeline(Set<String> modelIds) {
         if (modelIds.isEmpty()) {
             return;
         }
@@ -410,7 +484,7 @@ public class ModelLoadingService implements ClusterStateListener {
         threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
             for (String modelId : modelIds) {
                 auditNewReferencedModel(modelId);
-                this.loadModel(modelId);
+                this.loadModel(modelId, Consumer.PIPELINE);
             }
         });
     }
@@ -436,11 +510,11 @@ public class ModelLoadingService implements ClusterStateListener {
         ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
             Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
             if (processors instanceof List) {
-                for(Object processor : (List)processors) {
+                for (Object processor : (List) processors) {
                     if (processor instanceof Map) {
-                        Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE);
+                        Object processorConfig = ((Map) processor).get(InferenceProcessor.TYPE);
                         if (processorConfig instanceof Map) {
-                            Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID);
+                            Object modelId = ((Map) processorConfig).get(InferenceProcessor.MODEL_ID);
                             if (modelId != null) {
                                 assert modelId instanceof String;
                                 allReferencedModelKeys.add(modelId.toString());
@@ -454,7 +528,7 @@ public class ModelLoadingService implements ClusterStateListener {
     }
 
     private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) {
-        switch(targetType) {
+        switch (targetType) {
             case REGRESSION:
                 return RegressionConfig.EMPTY_PARAMS;
             case CLASSIFICATION:
@@ -466,22 +540,22 @@ public class ModelLoadingService implements ClusterStateListener {
 
     /**
      * Register a listener for notification when a model is loaded.
-     *
+     * <p>
      * This method is primarily intended for testing (hence package private)
     * and shouldn't be required outside of testing.
      *
-     * @param modelId Model Id
+     * @param modelId             Model Id
      * @param modelLoadedListener To be notified
      */
     void addModelLoadedListener(String modelId, ActionListener<Model> modelLoadedListener) {
         synchronized (loadingListeners) {
             loadingListeners.compute(modelId, (modelKey, listenerQueue) -> {
-                    if (listenerQueue == null) {
-                        return addFluently(new ArrayDeque<>(), modelLoadedListener);
-                    } else {
-                        return addFluently(listenerQueue, modelLoadedListener);
-                    }
-                });
+                if (listenerQueue == null) {
+                    return addFluently(new ArrayDeque<>(), modelLoadedListener);
+                } else {
+                    return addFluently(listenerQueue, modelLoadedListener);
+                }
+            });
         }
     }
 }
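Before the test changes, a quick illustration of how the two public entry points are meant to be called. This is hypothetical caller code, assuming a ModelLoadingService and a Logger are in scope; only the method signatures come from the diff:

import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;

void loadForBothFeatures(ModelLoadingService modelLoadingService, Logger logger) {
    // Search: the model is cached and tagged with Consumer.SEARCH, so pipeline
    // dereferencing (the clusterChanged logic above) will not evict it.
    modelLoadingService.getModelForSearch("my-model", ActionListener.wrap(
        model -> logger.info("[my-model] ready for search"),
        e -> logger.warn("[my-model] failed to load", e)));

    // Ingest: cached only while an ingest pipeline references the model; a
    // simulate-pipeline call takes the load-without-caching path instead.
    modelLoadingService.getModelForPipeline("my-model", ActionListener.wrap(
        model -> logger.info("[my-model] ready for ingest"),
        e -> logger.warn("[my-model] failed to load", e)));
}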
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java
index 98f38bcd52c..5c2f18cc29e 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java
@@ -131,7 +131,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         for(int i = 0; i < 10; i++) {
             String model = modelIds[i%3];
             PlainActionFuture<Model> future = new PlainActionFuture<>();
-            modelLoadingService.getModel(model, future);
+            modelLoadingService.getModelForPipeline(model, future);
             assertThat(future.get(), is(not(nullValue())));
         }
 
@@ -139,12 +139,16 @@ public class ModelLoadingServiceTests extends ESTestCase {
         verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any());
         verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any());
 
+        assertTrue(modelLoadingService.isModelCached(model1));
+        assertTrue(modelLoadingService.isModelCached(model2));
+        assertTrue(modelLoadingService.isModelCached(model3));
+
         // Test invalidate cache for model3
         modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
         for(int i = 0; i < 10; i++) {
             String model = modelIds[i%3];
             PlainActionFuture<Model> future = new PlainActionFuture<>();
-            modelLoadingService.getModel(model, future);
+            modelLoadingService.getModelForPipeline(model, future);
             assertThat(future.get(), is(not(nullValue())));
         }
 
@@ -196,7 +200,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
             // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
             String model = modelIds[i%2];
             PlainActionFuture<Model> future = new PlainActionFuture<>();
-            modelLoadingService.getModel(model, future);
+            modelLoadingService.getModelForPipeline(model, future);
             assertThat(future.get(), is(not(nullValue())));
         }
 
@@ -219,7 +223,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         // Load model 3, should invalidate 1 and 2
         for(int i = 0; i < 10; i++) {
             PlainActionFuture<Model> future3 = new PlainActionFuture<>();
-            modelLoadingService.getModel(model3, future3);
+            modelLoadingService.getModelForPipeline(model3, future3);
             assertThat(future3.get(), is(not(nullValue())));
         }
         verify(trainedModelProvider, times(2)).getTrainedModelForInference(eq(model3), any());
@@ -240,7 +244,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         // Load model 1, should invalidate 3
         for(int i = 0; i < 10; i++) {
             PlainActionFuture<Model> future1 = new PlainActionFuture<>();
-            modelLoadingService.getModel(model1, future1);
+            modelLoadingService.getModelForPipeline(model1, future1);
             assertThat(future1.get(), is(not(nullValue())));
         }
         verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any());
@@ -254,7 +258,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         // Load model 2
         for(int i = 0; i < 10; i++) {
             PlainActionFuture<Model> future2 = new PlainActionFuture<>();
-            modelLoadingService.getModel(model2, future2);
+            modelLoadingService.getModelForPipeline(model2, future2);
             assertThat(future2.get(), is(not(nullValue())));
         }
         verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any());
@@ -265,7 +269,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         for(int i = 0; i < 10; i++) {
             String model = modelIds[i%3];
             PlainActionFuture<Model> future = new PlainActionFuture<>();
-            modelLoadingService.getModel(model, future);
+            modelLoadingService.getModelForPipeline(model, future);
             assertThat(future.get(), is(not(nullValue())));
         }
 
@@ -292,10 +296,11 @@ public class ModelLoadingServiceTests extends ESTestCase {
 
         for(int i = 0; i < 10; i++) {
             PlainActionFuture<Model> future = new PlainActionFuture<>();
-            modelLoadingService.getModel(model1, future);
+            modelLoadingService.getModelForPipeline(model1, future);
             assertThat(future.get(), is(not(nullValue())));
         }
 
+        assertFalse(modelLoadingService.isModelCached(model1));
         verify(trainedModelProvider, times(10)).getTrainedModelForInference(eq(model1), any());
         verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
     }
@@ -315,7 +320,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         modelLoadingService.clusterChanged(ingestChangedEvent(model));
 
         PlainActionFuture<Model> future = new PlainActionFuture<>();
-        modelLoadingService.getModel(model, future);
+        modelLoadingService.getModelForPipeline(model, future);
 
         try {
             future.get();
@@ -323,6 +328,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
         } catch (Exception ex) {
             assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
         }
+        assertFalse(modelLoadingService.isModelCached(model));
 
         verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model), any());
         verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
@@ -342,13 +348,14 @@ public class ModelLoadingServiceTests extends ESTestCase {
             circuitBreaker);
 
         PlainActionFuture<Model> future = new PlainActionFuture<>();
-        modelLoadingService.getModel(model, future);
+        modelLoadingService.getModelForPipeline(model, future);
         try {
             future.get();
             fail("Should not have succeeded");
         } catch (Exception ex) {
             assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
         }
+        assertFalse(modelLoadingService.isModelCached(model));
     }
 
     public void testGetModelEagerly() throws Exception {
@@ -366,11 +373,37 @@ public class ModelLoadingServiceTests extends ESTestCase {
 
         for(int i = 0; i < 3; i++) {
             PlainActionFuture<Model> future = new PlainActionFuture<>();
-            modelLoadingService.getModel(model, future);
+            modelLoadingService.getModelForPipeline(model, future);
             assertThat(future.get(), is(not(nullValue())));
         }
 
         verify(trainedModelProvider, times(3)).getTrainedModelForInference(eq(model), any());
+        assertFalse(modelLoadingService.isModelCached(model));
         verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
     }
 
+    public void testGetModelForSearch() throws Exception {
+        String modelId = "test-get-model-for-search";
+        withTrainedModel(modelId, 1L);
+
+        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
+            auditor,
+            threadPool,
+            clusterService,
+            trainedModelStatsService,
+            Settings.EMPTY,
+            "test-node",
+            circuitBreaker);
+
+        for(int i = 0; i < 3; i++) {
+            PlainActionFuture<Model> future = new PlainActionFuture<>();
+            modelLoadingService.getModelForSearch(modelId, future);
+            assertThat(future.get(), is(not(nullValue())));
+        }
+
+        assertTrue(modelLoadingService.isModelCached(modelId));
+
+        verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(modelId), any());
+        verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
    }
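The new testGetModelForSearch covers caching; the eviction-sparing behaviour in clusterChanged could be exercised with a follow-on assertion along these lines. This is hypothetical test code, not part of the change, and it assumes the ingestChangedEvent helper accepts a varargs model list so it can be called with no referenced models:

    public void testSearchModelSurvivesPipelineRemoval() throws Exception {
        String modelId = "test-search-model-survives";
        withTrainedModel(modelId, 1L);
        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
            auditor, threadPool, clusterService, trainedModelStatsService,
            Settings.EMPTY, "test-node", circuitBreaker);

        PlainActionFuture<Model> future = new PlainActionFuture<>();
        modelLoadingService.getModelForSearch(modelId, future);
        assertThat(future.get(), is(not(nullValue())));

        // No ingest pipeline references the model; the SEARCH consumer tag
        // should keep it cached through the cluster state change.
        modelLoadingService.clusterChanged(ingestChangedEvent());
        assertTrue(modelLoadingService.isModelCached(modelId));
    }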