Add models for search to ModelLoadingService (#57592) (#57919)

ModelLoadingService only caches models if they are referenced by an 
ingest pipeline. For models used in search we want to always cache the
models and rely on TTL to evict them. Additionally when an ingest 
pipeline is deleted the model it references should not be evicted if 
it is used in search.
This commit is contained in:
David Kyle 2020-06-11 10:48:37 +01:00 committed by GitHub
parent 2905a2f623
commit b87b147704
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 219 additions and 114 deletions

View File

@ -75,13 +75,13 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
if (licenseState.isAllowed(XPackLicenseState.Feature.MACHINE_LEARNING)) {
responseBuilder.setLicensed(true);
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
} else {
trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
trainedModelConfig -> {
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}

View File

@ -32,7 +32,6 @@ public class LocalModel implements Model {
private final InferenceDefinition trainedModelDefinition;
private final String modelId;
private final String nodeId;
private final Set<String> fieldNames;
private final Map<String, String> defaultFieldMap;
private final InferenceStats.Accumulator statsAccumulator;
@ -50,7 +49,6 @@ public class LocalModel implements Model {
TrainedModelStatsService trainedModelStatsService ) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.nodeId = nodeId;
this.fieldNames = new HashSet<>(input.getFieldNames());
this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
this.trainedModelStatsService = trainedModelStatsService;

View File

@ -43,6 +43,7 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@ -52,9 +53,11 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;
/**
* This is a thread safe model loading service.
* This is a thread safe model loading service with LRU cache.
* Cache entries have a TTL before they are evicted.
*
* It will cache local models that are referenced by processors in memory (as long as it is instantiated on an ingest node).
* In the case of a pipeline processor requesting the model if the processor is in simulate
* mode the model is not cached. All other uses will cache the model
*
* If more than one processor references the same model, that model will only be cached once.
*/
@ -62,7 +65,7 @@ public class ModelLoadingService implements ClusterStateListener {
/**
* The maximum size of the local model cache here in the loading service
*
* <p>
* Once the limit is reached, LRU models are evicted in favor of new models
*/
public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE =
@ -72,12 +75,11 @@ public class ModelLoadingService implements ClusterStateListener {
/**
* How long should a model stay in the cache since its last access
*
* <p>
* If nothing references a model via getModel for this configured timeValue, it will be evicted.
*
* <p>
* Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not
* executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called.
*
*/
public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL =
Setting.timeSetting("xpack.ml.inference_model.time_to_live",
@ -85,9 +87,25 @@ public class ModelLoadingService implements ClusterStateListener {
new TimeValue(1, TimeUnit.MILLISECONDS),
Setting.Property.NodeScope);
// The feature requesting the model
public enum Consumer {
PIPELINE, SEARCH
}
private static class ModelAndConsumer {
private LocalModel model;
private EnumSet<Consumer> consumers;
private ModelAndConsumer(LocalModel model, Consumer consumer) {
this.model = model;
this.consumers = EnumSet.of(consumer);
}
}
private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
private final TrainedModelStatsService modelStatsService;
private final Cache<String, LocalModel> localModelCache;
private final Cache<String, ModelAndConsumer> localModelCache;
private final Set<String> referencedModels = new HashSet<>();
private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
private final TrainedModelProvider provider;
@ -112,9 +130,9 @@ public class ModelLoadingService implements ClusterStateListener {
this.auditor = auditor;
this.modelStatsService = modelStatsService;
this.shouldNotAudit = new HashSet<>();
this.localModelCache = CacheBuilder.<String, LocalModel>builder()
this.localModelCache = CacheBuilder.<String, ModelAndConsumer>builder()
.setMaximumWeight(this.maxCacheSize.getBytes())
.weigher((id, localModel) -> localModel.ramBytesUsed())
.weigher((id, modelAndConsumer) -> modelAndConsumer.model.ramBytesUsed())
// explicit declaration of the listener lambda necessary for Eclipse IDE 4.14
.removalListener(notification -> cacheEvictionListener(notification))
.setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
@ -124,13 +142,43 @@ public class ModelLoadingService implements ClusterStateListener {
this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
}
boolean isModelCached(String modelId) {
return localModelCache.get(modelId) != null;
}
/**
* Load the model for use by an ingest pipeline. The model will not be cached if there is no
* ingest pipeline referencing it i.e. it is used in simulate mode
*
* @param modelId the model to get
* @param modelActionListener the listener to alert when the model has been retrieved
*/
public void getModelForPipeline(String modelId, ActionListener<Model> modelActionListener) {
getModel(modelId, Consumer.PIPELINE, modelActionListener);
}
/**
* Load the model for use by at search. Models requested by search are always cached.
*
* @param modelId the model to get
* @param modelActionListener the listener to alert when the model has been retrieved
*/
public void getModelForSearch(String modelId, ActionListener<Model> modelActionListener) {
getModel(modelId, Consumer.SEARCH, modelActionListener);
}
/**
* Gets the model referenced by `modelId` and responds to the listener.
*
* <p>
* This method first checks the local LRU cache for the model. If it is present, it is returned from cache.
* <p>
* In the case of search if the model is not present one of the following occurs:
* - If it is currently being loaded the `modelActionListener`
* is added to the list of listeners to be alerted when the model is fully loaded.
* - Otherwise the model is loaded and cached
*
* If it is not present, one of the following occurs:
*
* In the case of an ingest processor if it is not present, one of the following occurs:
* <p>
* - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
* is added to the list of listeners to be alerted when the model is fully loaded.
* - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting
@ -138,17 +186,108 @@ public class ModelLoadingService implements ClusterStateListener {
* - If the models is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
* It is not cached.
*
* The main difference being that models for search are always cached whereas pipeline models
* are only cached if they are referenced by an ingest pipeline
*
* @param modelId the model to get
* @param consumer which feature is requesting the model
* @param modelActionListener the listener to alert when the model has been retrieved.
*/
public void getModel(String modelId, ActionListener<Model> modelActionListener) {
LocalModel cachedModel = localModelCache.get(modelId);
private void getModel(String modelId, Consumer consumer, ActionListener<Model> modelActionListener) {
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
cachedModel.consumers.add(consumer);
modelActionListener.onResponse(cachedModel.model);
logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
return;
}
if (loadModelIfNecessary(modelId, modelActionListener) == false) {
if (loadModelIfNecessary(modelId, consumer, modelActionListener)) {
logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
}
}
/**
* If the model is cached it is returned directly to the listener
* else if the model is CURRENTLY being loaded the listener is added to be notified when it is loaded
* else the model load is initiated.
*
* @param modelId The model to get
* @param consumer The model consumer
* @param modelActionListener The listener
* @return If the model is cached or currently being loaded true is returned. If a new load is started
* false is returned to indicate a new load event
*/
private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionListener<Model> modelActionListener) {
synchronized (loadingListeners) {
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
cachedModel.consumers.add(consumer);
modelActionListener.onResponse(cachedModel.model);
return true;
}
// Add the listener to the queue if the model is loading
Queue<ActionListener<Model>> listeners = loadingListeners.computeIfPresent(modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener));
// The cachedModel entry is null, but there are listeners present, that means it is being loaded
if (listeners != null) {
return true;
}
if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
// The model is requested by a pipeline but not referenced by any ingest pipelines.
// This means it is a simulate call and the model should not be cached
loadWithoutCaching(modelId, modelActionListener);
} else {
logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
loadModel(modelId, consumer);
}
return false;
} // synchronized (loadingListeners)
}
private void loadModel(String modelId, Consumer consumer) {
provider.getTrainedModel(modelId, false, ActionListener.wrap(
trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
inferenceDefinition -> {
// Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory
// So that the memory this model uses in the circuit breaker is the most accurate estimate.
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
if (estimateDiff < 0) {
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
try {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
handleLoadFailure(modelId, ex);
return;
}
}
handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
},
failure -> {
// We failed to get the definition, remove the initial estimation.
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure);
handleLoadFailure(modelId, failure);
}
));
},
failure -> {
logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure);
handleLoadFailure(modelId, failure);
}
));
}
private void loadWithoutCaching(String modelId, ActionListener<Model> modelActionListener) {
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
@ -181,81 +320,10 @@ public class ModelLoadingService implements ClusterStateListener {
},
modelActionListener::onFailure
));
} else {
logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
}
}
/**
* Returns true if the model is loaded and the listener has been given the cached model
* Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
* Returns false if the model is not loaded or actively being loaded
*/
private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
synchronized (loadingListeners) {
Model cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
return true;
}
// It is referenced by a pipeline, but the cache does not contain it
if (referencedModels.contains(modelId)) {
// If the loaded model is referenced there but is not present,
// that means the previous load attempt failed or the model has been evicted
// Attempt to load and cache the model if necessary
if (loadingListeners.computeIfPresent(
modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
loadModel(modelId);
}
return true;
}
// if the cachedModel entry is null, but there are listeners present, that means it is being loaded
return loadingListeners.computeIfPresent(modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null;
} // synchronized (loadingListeners)
}
private void loadModel(String modelId) {
provider.getTrainedModel(modelId, false, ActionListener.wrap(
trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
inferenceDefinition -> {
// Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory
// So that the memory this model uses in the circuit breaker is the most accurate estimate.
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
if (estimateDiff < 0) {
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
try {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
handleLoadFailure(modelId, ex);
return;
}
}
handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition);
},
failure -> {
// We failed to get the definition, remove the initial estimation.
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure);
handleLoadFailure(modelId, failure);
}
));
},
failure -> {
logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure);
handleLoadFailure(modelId, failure);
}
));
}
private void handleLoadSuccess(String modelId,
Consumer consumer,
TrainedModelConfig trainedModelConfig,
InferenceDefinition inferenceDefinition) {
Queue<ActionListener<Model>> listeners;
@ -278,7 +346,7 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed());
return;
}
localModelCache.put(modelId, loadedModel);
localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
shouldNotAudit.remove(modelId);
} // synchronized (loadingListeners)
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
@ -301,7 +369,7 @@ public class ModelLoadingService implements ClusterStateListener {
}
}
private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
try {
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
MessageSupplier msg = () -> new ParameterizedMessage(
@ -310,15 +378,15 @@ public class ModelLoadingService implements ClusterStateListener {
"If this is undesired, consider updating setting [{}] or [{}].",
new ByteSizeValue(localModelCache.weight()).getStringRep(),
maxCacheSize.getStringRep(),
new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
new ByteSizeValue(notification.getValue().model.ramBytesUsed()).getStringRep(),
INFERENCE_MODEL_CACHE_SIZE.getKey(),
INFERENCE_MODEL_CACHE_TTL.getKey());
auditIfNecessary(notification.getKey(), msg);
}
// If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false);
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
} finally {
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed());
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().model.ramBytesUsed());
}
}
@ -356,7 +424,13 @@ public class ModelLoadingService implements ClusterStateListener {
removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
// Remove all cached models that are not referenced by any processors
removedModels.forEach(localModelCache::invalidate);
// and are not used in search
removedModels.forEach(modelId -> {
ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
localModelCache.invalidate(modelId);
}
});
// Remove the models that are no longer referenced
referencedModels.removeAll(removedModels);
shouldNotAudit.removeAll(removedModels);
@ -389,7 +463,7 @@ public class ModelLoadingService implements ClusterStateListener {
}
}
removedModels.forEach(this::auditUnreferencedModel);
loadModels(allReferencedModelKeys);
loadModelsForPipeline(allReferencedModelKeys);
}
private void auditIfNecessary(String modelId, MessageSupplier msg) {
@ -402,7 +476,7 @@ public class ModelLoadingService implements ClusterStateListener {
logger.info("[{}] {}", modelId, msg.get().getFormattedMessage());
}
private void loadModels(Set<String> modelIds) {
private void loadModelsForPipeline(Set<String> modelIds) {
if (modelIds.isEmpty()) {
return;
}
@ -410,7 +484,7 @@ public class ModelLoadingService implements ClusterStateListener {
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
for (String modelId : modelIds) {
auditNewReferencedModel(modelId);
this.loadModel(modelId);
this.loadModel(modelId, Consumer.PIPELINE);
}
});
}
@ -466,7 +540,7 @@ public class ModelLoadingService implements ClusterStateListener {
/**
* Register a listener for notification when a model is loaded.
*
* <p>
* This method is primarily intended for testing (hence package private)
* and shouldn't be required outside of testing.
*

View File

@ -131,7 +131,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -139,12 +139,16 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any());
assertTrue(modelLoadingService.isModelCached(model1));
assertTrue(modelLoadingService.isModelCached(model2));
assertTrue(modelLoadingService.isModelCached(model3));
// Test invalidate cache for model3
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -196,7 +200,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
String model = modelIds[i%2];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -219,7 +223,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 3, should invalidate 1 and 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future3 = new PlainActionFuture<>();
modelLoadingService.getModel(model3, future3);
modelLoadingService.getModelForPipeline(model3, future3);
assertThat(future3.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(2)).getTrainedModelForInference(eq(model3), any());
@ -240,7 +244,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 1, should invalidate 3
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future1 = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future1);
modelLoadingService.getModelForPipeline(model1, future1);
assertThat(future1.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any());
@ -254,7 +258,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future2 = new PlainActionFuture<>();
modelLoadingService.getModel(model2, future2);
modelLoadingService.getModelForPipeline(model2, future2);
assertThat(future2.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any());
@ -265,7 +269,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@ -292,10 +296,11 @@ public class ModelLoadingServiceTests extends ESTestCase {
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future);
modelLoadingService.getModelForPipeline(model1, future);
assertThat(future.get(), is(not(nullValue())));
}
assertFalse(modelLoadingService.isModelCached(model1));
verify(trainedModelProvider, times(10)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
@ -315,7 +320,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.clusterChanged(ingestChangedEvent(model));
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
try {
future.get();
@ -323,6 +328,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
} catch (Exception ex) {
assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
assertFalse(modelLoadingService.isModelCached(model));
verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
@ -342,13 +348,14 @@ public class ModelLoadingServiceTests extends ESTestCase {
circuitBreaker);
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
try {
future.get();
fail("Should not have succeeded");
} catch (Exception ex) {
assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
assertFalse(modelLoadingService.isModelCached(model));
}
public void testGetModelEagerly() throws Exception {
@ -366,11 +373,37 @@ public class ModelLoadingServiceTests extends ESTestCase {
for(int i = 0; i < 3; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(3)).getTrainedModelForInference(eq(model), any());
assertFalse(modelLoadingService.isModelCached(model));
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
public void testGetModelForSearch() throws Exception {
String modelId = "test-get-model-for-search";
withTrainedModel(modelId, 1L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
trainedModelStatsService,
Settings.EMPTY,
"test-node",
circuitBreaker);
for(int i = 0; i < 3; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModelForSearch(modelId, future);
assertThat(future.get(), is(not(nullValue())));
}
assertTrue(modelLoadingService.isModelCached(modelId));
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(modelId), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}