ModelLoadingService only caches models if they are referenced by an ingest pipeline. For models used in search, we want to always cache the models and rely on the TTL to evict them. Additionally, when an ingest pipeline is deleted, the model it references should not be evicted from the cache if it is also used in search.
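The sketch below is illustrative only and is not part of the change: it shows how a caller would use the two entry points this commit introduces, getModelForPipeline and getModelForSearch (both visible in the diff). The "my-model" id, the listener bodies, and the modelLoadingService reference are assumed placeholders.

// Minimal usage sketch; Model and ActionListener are the Elasticsearch types used throughout the diff.
ActionListener<Model> listener = ActionListener.wrap(
    model -> { /* run inference with the loaded model */ },
    e -> { /* handle the load failure */ }
);

// Ingest: cached only while an ingest pipeline references the model; a simulate call loads it without caching.
modelLoadingService.getModelForPipeline("my-model", listener);

// Search: always cached; evicted by the cache TTL (xpack.ml.inference_model.time_to_live) or LRU pressure,
// not by deleting an ingest pipeline.
modelLoadingService.getModelForSearch("my-model", listener);

Either way the caller works with the returned Model in the same way; only the caching policy differs.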
This commit is contained in:
parent 2905a2f623
commit b87b147704
@@ -75,13 +75,13 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
if (licenseState.isAllowed(XPackLicenseState.Feature.MACHINE_LEARNING)) {
responseBuilder.setLicensed(true);
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
} else {
trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
trainedModelConfig -> {
responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
@@ -32,7 +32,6 @@ public class LocalModel implements Model {
private final InferenceDefinition trainedModelDefinition;
private final String modelId;
private final String nodeId;
private final Set<String> fieldNames;
private final Map<String, String> defaultFieldMap;
private final InferenceStats.Accumulator statsAccumulator;
@@ -50,7 +49,6 @@ public class LocalModel implements Model {
TrainedModelStatsService trainedModelStatsService ) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.nodeId = nodeId;
this.fieldNames = new HashSet<>(input.getFieldNames());
this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
this.trainedModelStatsService = trainedModelStatsService;
@@ -43,6 +43,7 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -52,9 +53,11 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
* This is a thread safe model loading service.
* This is a thread safe model loading service with LRU cache.
* Cache entries have a TTL before they are evicted.
*
* It will cache local models that are referenced by processors in memory (as long as it is instantiated on an ingest node).
* In the case of a pipeline processor requesting the model if the processor is in simulate
* mode the model is not cached. All other uses will cache the model
*
* If more than one processor references the same model, that model will only be cached once.
*/
@@ -62,7 +65,7 @@ public class ModelLoadingService implements ClusterStateListener {
/**
* The maximum size of the local model cache here in the loading service
*
* <p>
* Once the limit is reached, LRU models are evicted in favor of new models
*/
public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE =
@@ -72,12 +75,11 @@ public class ModelLoadingService implements ClusterStateListener {
/**
* How long should a model stay in the cache since its last access
*
* <p>
* If nothing references a model via getModel for this configured timeValue, it will be evicted.
*
* <p>
* Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not
* executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called.
*
*/
public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL =
Setting.timeSetting("xpack.ml.inference_model.time_to_live",
@@ -85,9 +87,25 @@ public class ModelLoadingService implements ClusterStateListener {
new TimeValue(1, TimeUnit.MILLISECONDS),
Setting.Property.NodeScope);

// The feature requesting the model
public enum Consumer {
PIPELINE, SEARCH
}

private static class ModelAndConsumer {
private LocalModel model;
private EnumSet<Consumer> consumers;

private ModelAndConsumer(LocalModel model, Consumer consumer) {
this.model = model;
this.consumers = EnumSet.of(consumer);
}
}

private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
private final TrainedModelStatsService modelStatsService;
private final Cache<String, LocalModel> localModelCache;
private final Cache<String, ModelAndConsumer> localModelCache;
private final Set<String> referencedModels = new HashSet<>();
private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
private final TrainedModelProvider provider;
@@ -112,9 +130,9 @@ public class ModelLoadingService implements ClusterStateListener {
this.auditor = auditor;
this.modelStatsService = modelStatsService;
this.shouldNotAudit = new HashSet<>();
this.localModelCache = CacheBuilder.<String, LocalModel>builder()
this.localModelCache = CacheBuilder.<String, ModelAndConsumer>builder()
.setMaximumWeight(this.maxCacheSize.getBytes())
.weigher((id, localModel) -> localModel.ramBytesUsed())
.weigher((id, modelAndConsumer) -> modelAndConsumer.model.ramBytesUsed())
// explicit declaration of the listener lambda necessary for Eclipse IDE 4.14
.removalListener(notification -> cacheEvictionListener(notification))
.setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
@@ -124,101 +142,115 @@ public class ModelLoadingService implements ClusterStateListener {
this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
}

boolean isModelCached(String modelId) {
return localModelCache.get(modelId) != null;
}

/**
* Load the model for use by an ingest pipeline. The model will not be cached if there is no
* ingest pipeline referencing it i.e. it is used in simulate mode
*
* @param modelId the model to get
* @param modelActionListener the listener to alert when the model has been retrieved
*/
public void getModelForPipeline(String modelId, ActionListener<Model> modelActionListener) {
getModel(modelId, Consumer.PIPELINE, modelActionListener);
}

/**
* Load the model for use at search. Models requested by search are always cached.
*
* @param modelId the model to get
* @param modelActionListener the listener to alert when the model has been retrieved
*/
public void getModelForSearch(String modelId, ActionListener<Model> modelActionListener) {
getModel(modelId, Consumer.SEARCH, modelActionListener);
}

/**
* Gets the model referenced by `modelId` and responds to the listener.
*
* <p>
* This method first checks the local LRU cache for the model. If it is present, it is returned from cache.
* <p>
* In the case of search if the model is not present one of the following occurs:
* - If it is currently being loaded the `modelActionListener`
* is added to the list of listeners to be alerted when the model is fully loaded.
* - Otherwise the model is loaded and cached
*
* If it is not present, one of the following occurs:
* In the case of an ingest processor if it is not present, one of the following occurs:
* <p>
* - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
* is added to the list of listeners to be alerted when the model is fully loaded.
* - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting
* model will attempt to be cached for future reference
* - If the model is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
* It is not cached.
*
* - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
* is added to the list of listeners to be alerted when the model is fully loaded.
* - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the resulting
* model will attempt to be cached for future reference
* - If the model is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
* It is not cached.
* The main difference being that models for search are always cached whereas pipeline models
* are only cached if they are referenced by an ingest pipeline
*
* @param modelId the model to get
* @param consumer which feature is requesting the model
* @param modelActionListener the listener to alert when the model has been retrieved.
*/
public void getModel(String modelId, ActionListener<Model> modelActionListener) {
LocalModel cachedModel = localModelCache.get(modelId);
private void getModel(String modelId, Consumer consumer, ActionListener<Model> modelActionListener) {
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
cachedModel.consumers.add(consumer);
modelActionListener.onResponse(cachedModel.model);
logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
return;
}
if (loadModelIfNecessary(modelId, modelActionListener) == false) {
// If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, false, ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
inferenceDefinition -> {
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
trainedModelConfig.getInferenceConfig();
// Remove the bytes as we cannot control how long the caller will keep the model in memory
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
modelActionListener.onResponse(new LocalModel(
trainedModelConfig.getModelId(),
localNode,
inferenceDefinition,
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
modelStatsService));
},
// Failure getting the definition, remove the initial estimation value
e -> {
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
modelActionListener.onFailure(e);
}
));
},
modelActionListener::onFailure
));
} else {

if (loadModelIfNecessary(modelId, consumer, modelActionListener)) {
logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
}
}

/**
* Returns true if the model is loaded and the listener has been given the cached model
* Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
* Returns false if the model is not loaded or actively being loaded
* If the model is cached it is returned directly to the listener
* else if the model is CURRENTLY being loaded the listener is added to be notified when it is loaded
* else the model load is initiated.
*
* @param modelId The model to get
* @param consumer The model consumer
* @param modelActionListener The listener
* @return If the model is cached or currently being loaded true is returned. If a new load is started
* false is returned to indicate a new load event
*/
private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionListener<Model> modelActionListener) {
synchronized (loadingListeners) {
Model cachedModel = localModelCache.get(modelId);
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
cachedModel.consumers.add(consumer);
modelActionListener.onResponse(cachedModel.model);
return true;
}
// It is referenced by a pipeline, but the cache does not contain it
if (referencedModels.contains(modelId)) {
// If the loaded model is referenced there but is not present,
// that means the previous load attempt failed or the model has been evicted
// Attempt to load and cache the model if necessary
if (loadingListeners.computeIfPresent(
modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
loadModel(modelId);
}

// Add the listener to the queue if the model is loading
Queue<ActionListener<Model>> listeners = loadingListeners.computeIfPresent(modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener));

// The cachedModel entry is null, but there are listeners present, that means it is being loaded
if (listeners != null) {
return true;
}
// if the cachedModel entry is null, but there are listeners present, that means it is being loaded
return loadingListeners.computeIfPresent(modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null;

if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
// The model is requested by a pipeline but not referenced by any ingest pipelines.
// This means it is a simulate call and the model should not be cached
loadWithoutCaching(modelId, modelActionListener);
} else {
logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
loadModel(modelId, consumer);
}

return false;
} // synchronized (loadingListeners)
}

private void loadModel(String modelId) {
private void loadModel(String modelId, Consumer consumer) {
provider.getTrainedModel(modelId, false, ActionListener.wrap(
trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
@@ -238,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener {
return;
}
}
handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition);
handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
},
failure -> {
// We failed to get the definition, remove the initial estimation.
@@ -255,7 +287,43 @@ public class ModelLoadingService implements ClusterStateListener {
));
}

private void loadWithoutCaching(String modelId, ActionListener<Model> modelActionListener) {
// If the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, false, ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
inferenceDefinition -> {
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
trainedModelConfig.getInferenceConfig();
// Remove the bytes as we cannot control how long the caller will keep the model in memory
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
modelActionListener.onResponse(new LocalModel(
trainedModelConfig.getModelId(),
localNode,
inferenceDefinition,
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
modelStatsService));
},
// Failure getting the definition, remove the initial estimation value
e -> {
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
modelActionListener.onFailure(e);
}
));
},
modelActionListener::onFailure
));
}

private void handleLoadSuccess(String modelId,
Consumer consumer,
TrainedModelConfig trainedModelConfig,
InferenceDefinition inferenceDefinition) {
Queue<ActionListener<Model>> listeners;
@@ -278,7 +346,7 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed());
return;
}
localModelCache.put(modelId, loadedModel);
localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
shouldNotAudit.remove(modelId);
} // synchronized (loadingListeners)
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
@@ -301,7 +369,7 @@ public class ModelLoadingService implements ClusterStateListener {
}
}

private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
try {
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
MessageSupplier msg = () -> new ParameterizedMessage(
@@ -310,15 +378,15 @@ public class ModelLoadingService implements ClusterStateListener {
"If this is undesired, consider updating setting [{}] or [{}].",
new ByteSizeValue(localModelCache.weight()).getStringRep(),
maxCacheSize.getStringRep(),
new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
new ByteSizeValue(notification.getValue().model.ramBytesUsed()).getStringRep(),
INFERENCE_MODEL_CACHE_SIZE.getKey(),
INFERENCE_MODEL_CACHE_TTL.getKey());
auditIfNecessary(notification.getKey(), msg);
}
// If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false);
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
} finally {
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed());
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().model.ramBytesUsed());
}
}
@@ -356,7 +424,13 @@ public class ModelLoadingService implements ClusterStateListener {
removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);

// Remove all cached models that are not referenced by any processors
removedModels.forEach(localModelCache::invalidate);
// and are not used in search
removedModels.forEach(modelId -> {
ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
localModelCache.invalidate(modelId);
}
});
// Remove the models that are no longer referenced
referencedModels.removeAll(removedModels);
shouldNotAudit.removeAll(removedModels);
@@ -389,7 +463,7 @@ public class ModelLoadingService implements ClusterStateListener {
}
}
removedModels.forEach(this::auditUnreferencedModel);
loadModels(allReferencedModelKeys);
loadModelsForPipeline(allReferencedModelKeys);
}

private void auditIfNecessary(String modelId, MessageSupplier msg) {
@@ -402,7 +476,7 @@ public class ModelLoadingService implements ClusterStateListener {
logger.info("[{}] {}", modelId, msg.get().getFormattedMessage());
}

private void loadModels(Set<String> modelIds) {
private void loadModelsForPipeline(Set<String> modelIds) {
if (modelIds.isEmpty()) {
return;
}
@@ -410,7 +484,7 @@ public class ModelLoadingService implements ClusterStateListener {
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
for (String modelId : modelIds) {
auditNewReferencedModel(modelId);
this.loadModel(modelId);
this.loadModel(modelId, Consumer.PIPELINE);
}
});
}
@@ -436,11 +510,11 @@ public class ModelLoadingService implements ClusterStateListener {
ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
if (processors instanceof List<?>) {
for(Object processor : (List<?>)processors) {
for (Object processor : (List<?>) processors) {
if (processor instanceof Map<?, ?>) {
Object processorConfig = ((Map<?, ?>)processor).get(InferenceProcessor.TYPE);
Object processorConfig = ((Map<?, ?>) processor).get(InferenceProcessor.TYPE);
if (processorConfig instanceof Map<?, ?>) {
Object modelId = ((Map<?, ?>)processorConfig).get(InferenceProcessor.MODEL_ID);
Object modelId = ((Map<?, ?>) processorConfig).get(InferenceProcessor.MODEL_ID);
if (modelId != null) {
assert modelId instanceof String;
allReferencedModelKeys.add(modelId.toString());
@@ -454,7 +528,7 @@ public class ModelLoadingService implements ClusterStateListener {
}

private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) {
switch(targetType) {
switch (targetType) {
case REGRESSION:
return RegressionConfig.EMPTY_PARAMS;
case CLASSIFICATION:
@@ -466,22 +540,22 @@ public class ModelLoadingService implements ClusterStateListener {

/**
* Register a listener for notification when a model is loaded.
*
* <p>
* This method is primarily intended for testing (hence package private)
* and shouldn't be required outside of testing.
*
* @param modelId Model Id
* @param modelLoadedListener To be notified
*/
void addModelLoadedListener(String modelId, ActionListener<Model> modelLoadedListener) {
synchronized (loadingListeners) {
loadingListeners.compute(modelId, (modelKey, listenerQueue) -> {
if (listenerQueue == null) {
return addFluently(new ArrayDeque<>(), modelLoadedListener);
} else {
return addFluently(listenerQueue, modelLoadedListener);
}
});
if (listenerQueue == null) {
return addFluently(new ArrayDeque<>(), modelLoadedListener);
} else {
return addFluently(listenerQueue, modelLoadedListener);
}
});
}
}
}
@@ -131,7 +131,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@@ -139,12 +139,16 @@
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any());
verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any());

assertTrue(modelLoadingService.isModelCached(model1));
assertTrue(modelLoadingService.isModelCached(model2));
assertTrue(modelLoadingService.isModelCached(model3));

// Test invalidate cache for model3
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@@ -196,7 +200,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
String model = modelIds[i%2];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@@ -219,7 +223,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 3, should invalidate 1 and 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future3 = new PlainActionFuture<>();
modelLoadingService.getModel(model3, future3);
modelLoadingService.getModelForPipeline(model3, future3);
assertThat(future3.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(2)).getTrainedModelForInference(eq(model3), any());
@@ -240,7 +244,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 1, should invalidate 3
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future1 = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future1);
modelLoadingService.getModelForPipeline(model1, future1);
assertThat(future1.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any());
@@ -254,7 +258,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
// Load model 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future2 = new PlainActionFuture<>();
modelLoadingService.getModel(model2, future2);
modelLoadingService.getModelForPipeline(model2, future2);
assertThat(future2.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any());
@@ -265,7 +269,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}
@@ -292,10 +296,11 @@ public class ModelLoadingServiceTests extends ESTestCase {

for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future);
modelLoadingService.getModelForPipeline(model1, future);
assertThat(future.get(), is(not(nullValue())));
}

assertFalse(modelLoadingService.isModelCached(model1));
verify(trainedModelProvider, times(10)).getTrainedModelForInference(eq(model1), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
@@ -315,7 +320,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
modelLoadingService.clusterChanged(ingestChangedEvent(model));

PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);

try {
future.get();
@@ -323,6 +328,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
} catch (Exception ex) {
assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
assertFalse(modelLoadingService.isModelCached(model));

verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
@@ -342,13 +348,14 @@ public class ModelLoadingServiceTests extends ESTestCase {
circuitBreaker);

PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
try {
future.get();
fail("Should not have succeeded");
} catch (Exception ex) {
assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
assertFalse(modelLoadingService.isModelCached(model));
}

public void testGetModelEagerly() throws Exception {
@@ -366,11 +373,37 @@ public class ModelLoadingServiceTests extends ESTestCase {

for(int i = 0; i < 3; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
modelLoadingService.getModelForPipeline(model, future);
assertThat(future.get(), is(not(nullValue())));
}

verify(trainedModelProvider, times(3)).getTrainedModelForInference(eq(model), any());
assertFalse(modelLoadingService.isModelCached(model));
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}

public void testGetModelForSearch() throws Exception {
String modelId = "test-get-model-for-search";
withTrainedModel(modelId, 1L);

ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
trainedModelStatsService,
Settings.EMPTY,
"test-node",
circuitBreaker);

for(int i = 0; i < 3; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModelForSearch(modelId, future);
assertThat(future.get(), is(not(nullValue())));
}

assertTrue(modelLoadingService.isModelCached(modelId));

verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(modelId), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}