Accounting for model size when models are not cached (#59607)

When an inference model is loaded it is accounted for in circuit breaker
and should not be released until there are no users of the model. Adds
a reference count to the model to track usage.
This commit is contained in:
David Kyle 2020-07-15 18:06:15 +01:00 committed by GitHub
parent ef9b14b07e
commit df7fc8f967
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 350 additions and 91 deletions

View File

@ -65,9 +65,14 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
model.infer(stringObjectMap, request.getUpdate(), chainedTask)));
typedChainTaskExecutor.execute(ActionListener.wrap(
inferenceResultsInterfaces ->
listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()),
listener::onFailure
inferenceResultsInterfaces -> {
model.release();
listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build());
},
e -> {
model.release();
listener.onFailure(e);
}
));
},
listener::onFailure

View File

@ -73,10 +73,11 @@ public class InferenceRunner {
LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId);
try {
PlainActionFuture<LocalModel> localModelPlainActionFuture = new PlainActionFuture<>();
modelLoadingService.getModelForSearch(modelId, localModelPlainActionFuture);
LocalModel localModel = localModelPlainActionFuture.actionGet();
modelLoadingService.getModelForPipeline(modelId, localModelPlainActionFuture);
TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config);
inferTestDocs(localModel, testDocsIterator);
try (LocalModel localModel = localModelPlainActionFuture.actionGet()) {
inferTestDocs(localModel, testDocsIterator);
}
} catch (Exception e) {
throw ExceptionsHelper.serverError("[{}] failed running inference on model [{}]", e, config.getId(), modelId);
}

View File

@ -49,69 +49,76 @@ public class InferencePipelineAggregator extends PipelineAggregator {
@Override
public InternalAggregation reduce(InternalAggregation aggregation, InternalAggregation.ReduceContext reduceContext) {
InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
(InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
try {
InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
(InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
Map<String, Object> inputFields = new HashMap<>();
List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
Map<String, Object> inputFields = new HashMap<>();
if (bucket.getDocCount() == 0) {
// ignore this empty bucket unless the doc count is used
if (bucketPathMap.containsKey("_count") == false) {
newBuckets.add(bucket);
continue;
}
}
for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
String aggName = entry.getKey();
String bucketPath = entry.getValue();
Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
if (propertyValue instanceof Number) {
double doubleVal = ((Number) propertyValue).doubleValue();
// NaN or infinite values indicate a missing value or a
// valid result of an invalid calculation. Either way only
// a valid number will do
if (Double.isFinite(doubleVal)) {
inputFields.put(aggName, doubleVal);
if (bucket.getDocCount() == 0) {
// ignore this empty bucket unless the doc count is used
if (bucketPathMap.containsKey("_count") == false) {
newBuckets.add(bucket);
continue;
}
} else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
if (Double.isFinite(doubleVal)) {
inputFields.put(aggName, doubleVal);
}
} else if (propertyValue instanceof StringTerms.Bucket) {
StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
inputFields.put(aggName, b.getKeyAsString());
} else if (propertyValue instanceof String) {
inputFields.put(aggName, propertyValue);
} else if (propertyValue != null) {
// Doubles, String terms or null are valid, any other type is an error
throw invalidAggTypeError(bucketPath, propertyValue);
}
for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
String aggName = entry.getKey();
String bucketPath = entry.getValue();
Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
if (propertyValue instanceof Number) {
double doubleVal = ((Number) propertyValue).doubleValue();
// NaN or infinite values indicate a missing value or a
// valid result of an invalid calculation. Either way only
// a valid number will do
if (Double.isFinite(doubleVal)) {
inputFields.put(aggName, doubleVal);
}
} else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
if (Double.isFinite(doubleVal)) {
inputFields.put(aggName, doubleVal);
}
} else if (propertyValue instanceof StringTerms.Bucket) {
StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
inputFields.put(aggName, b.getKeyAsString());
} else if (propertyValue instanceof String) {
inputFields.put(aggName, propertyValue);
} else if (propertyValue != null) {
// Doubles, String terms or null are valid, any other type is an error
throw invalidAggTypeError(bucketPath, propertyValue);
}
}
InferenceResults inference;
try {
inference = model.infer(inputFields, configUpdate);
} catch (Exception e) {
inference = new WarningInferenceResults(e.getMessage());
}
final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
(p) -> (InternalAggregation) p).collect(Collectors.toList());
InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
aggs.add(aggResult);
InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
newBuckets.add(newBucket);
}
// the model is released at the end of this block.
assert model.getReferenceCount() > 0;
InferenceResults inference;
try {
inference = model.infer(inputFields, configUpdate);
} catch (Exception e) {
inference = new WarningInferenceResults(e.getMessage());
}
final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
(p) -> (InternalAggregation) p).collect(Collectors.toList());
InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
aggs.add(aggResult);
InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
newBuckets.add(newBucket);
return originalAgg.create(newBuckets);
} finally {
model.release();
}
return originalAgg.create(newBuckets);
}
public static Object resolveBucketValue(MultiBucketsAggregation agg,

View File

@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.license.License;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@ -19,16 +20,29 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import java.io.Closeable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
public class LocalModel {
/**
* LocalModels implement reference counting for proper accounting in
* the {@link CircuitBreaker}. When the model is not longer used {@link #release()}
* must be called and if the reference count == 0 then the model's bytes
* will be removed from the circuit breaker.
*
* The class is constructed with an initial reference count of 1 and its
* bytes <em>must</em> have been added to the circuit breaker before construction.
* New references must call {@link #acquire()} and {@link #release()} as the model
* is used.
*/
public class LocalModel implements Closeable {
private final InferenceDefinition trainedModelDefinition;
private final String modelId;
@ -40,15 +54,18 @@ public class LocalModel {
private final LongAdder currentInferenceCount;
private final InferenceConfig inferenceConfig;
private final License.OperationMode licenseLevel;
private final CircuitBreaker trainedModelCircuitBreaker;
private final AtomicLong referenceCount;
public LocalModel(String modelId,
LocalModel(String modelId,
String nodeId,
InferenceDefinition trainedModelDefinition,
TrainedModelInput input,
Map<String, String> defaultFieldMap,
InferenceConfig modelInferenceConfig,
License.OperationMode licenseLevel,
TrainedModelStatsService trainedModelStatsService) {
TrainedModelStatsService trainedModelStatsService,
CircuitBreaker trainedModelCircuitBreaker) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
this.fieldNames = new HashSet<>(input.getFieldNames());
@ -60,6 +77,8 @@ public class LocalModel {
this.currentInferenceCount = new LongAdder();
this.inferenceConfig = modelInferenceConfig;
this.licenseLevel = licenseLevel;
this.trainedModelCircuitBreaker = trainedModelCircuitBreaker;
this.referenceCount = new AtomicLong(1);
}
long ramBytesUsed() {
@ -177,4 +196,36 @@ public class LocalModel {
});
}
}
long acquire() {
long count = referenceCount.incrementAndGet();
// protect against a race where the model could be release to a
// count of zero then the model is quickly re-acquired
if (count == 1) {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelDefinition.ramBytesUsed(), modelId);
}
return count;
}
public long getReferenceCount() {
return referenceCount.get();
}
public long release() {
long count = referenceCount.decrementAndGet();
assert count >= 0;
if (count == 0) {
// no references to this model, it no longer needs to be accounted for
trainedModelCircuitBreaker.addWithoutBreaking(-ramBytesUsed());
}
return referenceCount.get();
}
/**
* Convenience method so the class can be used in try-with-resource
* constructs to invoke {@link #release()}.
*/
public void close() {
release();
}
}

View File

@ -60,6 +60,12 @@ import java.util.concurrent.TimeUnit;
* mode the model is not cached. All other uses will cache the model
*
* If more than one processor references the same model, that model will only be cached once.
*
* LocalModels are created with a reference count of 1 accounting for the reference the
* cache holds. When models are evicted from the cache the reference count is decremented.
* The {@code getModelForX} methods automatically increment the model's reference count
* it is up to the consumer to call {@link LocalModel#release()} when the model is no
* longer used.
*/
public class ModelLoadingService implements ClusterStateListener {
@ -93,8 +99,8 @@ public class ModelLoadingService implements ClusterStateListener {
}
private static class ModelAndConsumer {
private LocalModel model;
private EnumSet<Consumer> consumers;
private final LocalModel model;
private final EnumSet<Consumer> consumers;
private ModelAndConsumer(LocalModel model, Consumer consumer) {
this.model = model;
@ -197,6 +203,12 @@ public class ModelLoadingService implements ClusterStateListener {
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
cachedModel.consumers.add(consumer);
try {
cachedModel.model.acquire();
} catch (CircuitBreakingException e) {
modelActionListener.onFailure(e);
return;
}
modelActionListener.onResponse(cachedModel.model);
logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
return;
@ -223,6 +235,12 @@ public class ModelLoadingService implements ClusterStateListener {
ModelAndConsumer cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
cachedModel.consumers.add(consumer);
try {
cachedModel.model.acquire();
} catch (CircuitBreakingException e) {
modelActionListener.onFailure(e);
return true;
}
modelActionListener.onResponse(cachedModel.model);
return true;
}
@ -256,20 +274,16 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
inferenceDefinition -> {
// Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory
// So that the memory this model uses in the circuit breaker is the most accurate estimate.
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
if (estimateDiff < 0) {
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
try {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
handleLoadFailure(modelId, ex);
return;
}
try {
// Since we have used the previously stored estimate to help guard against OOM we need
// to adjust the memory so that the memory this model uses in the circuit breaker
// is the most accurate estimate.
updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
} catch (CircuitBreakingException ex) {
handleLoadFailure(modelId, ex);
return;
}
handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
},
failure -> {
@ -300,8 +314,14 @@ public class ModelLoadingService implements ClusterStateListener {
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
trainedModelConfig.getInferenceConfig();
// Remove the bytes as we cannot control how long the caller will keep the model in memory
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
try {
updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
} catch (CircuitBreakingException ex) {
modelActionListener.onFailure(ex);
return;
}
modelActionListener.onResponse(new LocalModel(
trainedModelConfig.getModelId(),
localNode,
@ -310,7 +330,8 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService));
modelStatsService,
trainedModelCircuitBreaker));
},
// Failure getting the definition, remove the initial estimation value
e -> {
@ -323,6 +344,21 @@ public class ModelLoadingService implements ClusterStateListener {
));
}
private void updateCircuitBreakerEstimate(String modelId, InferenceDefinition inferenceDefinition,
TrainedModelConfig trainedModelConfig) throws CircuitBreakingException {
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
if (estimateDiff < 0) {
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
try {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
throw ex;
}
}
}
private void handleLoadSuccess(String modelId,
Consumer consumer,
TrainedModelConfig trainedModelConfig,
@ -339,21 +375,30 @@ public class ModelLoadingService implements ClusterStateListener {
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService);
modelStatsService,
trainedModelCircuitBreaker);
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
// Consequently, we should not store the retrieved model
if (listeners == null) {
trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed());
loadedModel.release();
return;
}
// temporarily increase the reference count before adding to
// the cache in case the model is evicted before the listeners
// are called in which case acquire() would throw.
loadedModel.acquire();
localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
shouldNotAudit.remove(modelId);
} // synchronized (loadingListeners)
for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
loadedModel.acquire();
listener.onResponse(loadedModel);
}
// account for the acquire in the synchronized block above
loadedModel.release();
}
private void handleLoadFailure(String modelId, Exception failure) {
@ -388,7 +433,7 @@ public class ModelLoadingService implements ClusterStateListener {
// If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
} finally {
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().model.ramBytesUsed());
notification.getValue().model.release();
}
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.license.License;
import org.elasticsearch.test.ESTestCase;
@ -51,9 +52,11 @@ import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.internal.verification.VerificationModeFactory.times;
public class LocalModelTests extends ESTestCase {
@ -75,7 +78,8 @@ public class LocalModelTests extends ESTestCase {
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
randomFrom(License.OperationMode.values()),
modelStatsService);
modelStatsService,
mock(CircuitBreaker.class));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field", Collections.singletonMap("bar", 0.5));
@ -105,7 +109,8 @@ public class LocalModelTests extends ESTestCase {
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
modelStatsService,
mock(CircuitBreaker.class));
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be"));
@ -148,7 +153,8 @@ public class LocalModelTests extends ESTestCase {
Collections.singletonMap("field.foo", "field.foo.keyword"),
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
modelStatsService,
mock(CircuitBreaker.class));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
put("field.bar", 0.5);
@ -204,7 +210,8 @@ public class LocalModelTests extends ESTestCase {
Collections.singletonMap("bar", "bar.keyword"),
RegressionConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
modelStatsService,
mock(CircuitBreaker.class));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
@ -232,7 +239,8 @@ public class LocalModelTests extends ESTestCase {
null,
RegressionConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService);
modelStatsService,
mock(CircuitBreaker.class));
Map<String, Object> fields = new HashMap<String, Object>() {{
put("something", 1.0);
@ -263,7 +271,8 @@ public class LocalModelTests extends ESTestCase {
null,
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService
modelStatsService,
mock(CircuitBreaker.class)
);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("field.foo", 1.0);
@ -309,6 +318,63 @@ public class LocalModelTests extends ESTestCase {
assertThat(fields, equalTo(expectedMap));
}
public void testReferenceCounting() throws IOException {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
String modelId = "ref-count-model";
List<String> inputFields = Arrays.asList("field.foo", "field.bar");
InferenceDefinition definition = InferenceDefinition.builder()
.setTrainedModel(buildClassificationInference(false))
.build();
{
CircuitBreaker breaker = mock(CircuitBreaker.class);
LocalModel model = new LocalModel(modelId,
"test-node",
definition,
new TrainedModelInput(inputFields),
null,
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService,
breaker
);
model.release();
verify(breaker, times(1)).addWithoutBreaking(eq(-definition.ramBytesUsed()));
verifyNoMoreInteractions(breaker);
assertEquals(0L, model.getReferenceCount());
// reacquire
model.acquire();
verify(breaker, times(1)).addEstimateBytesAndMaybeBreak(eq(definition.ramBytesUsed()), eq(modelId));
verifyNoMoreInteractions(breaker);
assertEquals(1L, model.getReferenceCount());
}
{
CircuitBreaker breaker = mock(CircuitBreaker.class);
LocalModel model = new LocalModel(modelId,
"test-node",
definition,
new TrainedModelInput(inputFields),
null,
ClassificationConfig.EMPTY_PARAMS,
License.OperationMode.PLATINUM,
modelStatsService,
breaker
);
model.acquire();
model.acquire();
model.release();
model.release();
model.release();
verify(breaker, times(1)).addWithoutBreaking(eq(-definition.ramBytesUsed()));
verifyNoMoreInteractions(breaker);
assertEquals(0L, model.getReferenceCount());
}
}
private static SingleValueInferenceResults getSingleValue(LocalModel model,
Map<String, Object> fields,
InferenceConfigUpdate config)

View File

@ -58,6 +58,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
@ -451,6 +452,89 @@ public class ModelLoadingServiceTests extends ESTestCase {
});
}
public void testReferenceCounting() throws Exception {
String modelId = "test-reference-counting";
withTrainedModel(modelId, 1L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
trainedModelStatsService,
Settings.EMPTY,
"test-node",
circuitBreaker);
modelLoadingService.clusterChanged(ingestChangedEvent(modelId));
PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
modelLoadingService.getModelForPipeline(modelId, forPipeline);
final LocalModel model = forPipeline.get();
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
PlainActionFuture<LocalModel> forSearch = new PlainActionFuture<>();
modelLoadingService.getModelForPipeline(modelId, forSearch);
forSearch.get();
assertBusy(() -> assertEquals(3, model.getReferenceCount()));
model.release();
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
PlainActionFuture<LocalModel> forSearch2 = new PlainActionFuture<>();
modelLoadingService.getModelForPipeline(modelId, forSearch2);
forSearch2.get();
assertBusy(() -> assertEquals(3, model.getReferenceCount()));
}
public void testReferenceCountingForPipeline() throws Exception {
String modelId = "test-reference-counting-for-pipeline";
withTrainedModel(modelId, 1L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
trainedModelStatsService,
Settings.EMPTY,
"test-node",
circuitBreaker);
modelLoadingService.clusterChanged(ingestChangedEvent(modelId));
PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
modelLoadingService.getModelForPipeline(modelId, forPipeline);
final LocalModel model = forPipeline.get();
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
PlainActionFuture<LocalModel> forPipeline2 = new PlainActionFuture<>();
modelLoadingService.getModelForPipeline(modelId, forPipeline2);
forPipeline2.get();
assertBusy(() -> assertEquals(3, model.getReferenceCount()));
// will cause the model to be evicted
modelLoadingService.clusterChanged(ingestChangedEvent());
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
}
public void testReferenceCounting_ModelIsNotCached() throws ExecutionException, InterruptedException {
String modelId = "test-reference-counting-not-cached";
withTrainedModel(modelId, 1L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
trainedModelStatsService,
Settings.EMPTY,
"test-node",
circuitBreaker);
PlainActionFuture<LocalModel> future = new PlainActionFuture<>();
modelLoadingService.getModelForPipeline(modelId, future);
LocalModel model = future.get();
assertEquals(1, model.getReferenceCount());
}
@SuppressWarnings("unchecked")
private void withTrainedModel(String modelId, long size) {
InferenceDefinition definition = mock(InferenceDefinition.class);