From df7fc8f96794d87f21749d421b770adac68fd0e5 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 15 Jul 2020 18:06:15 +0100 Subject: [PATCH] Accounting for model size when models are not cached (#59607) When an inference model is loaded it is accounted for in circuit breaker and should not be released until there are no users of the model. Adds a reference count to the model to track usage. --- .../TransportInternalInferModelAction.java | 11 +- .../dataframe/inference/InferenceRunner.java | 7 +- .../aggs/InferencePipelineAggregator.java | 117 ++++++++++-------- .../inference/loadingservice/LocalModel.java | 57 ++++++++- .../loadingservice/ModelLoadingService.java | 87 +++++++++---- .../loadingservice/LocalModelTests.java | 78 +++++++++++- .../ModelLoadingServiceTests.java | 84 +++++++++++++ 7 files changed, 350 insertions(+), 91 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index 66593f222e1..c94c668a87b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -65,9 +65,14 @@ public class TransportInternalInferModelAction extends HandledTransportAction - listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()), - listener::onFailure + inferenceResultsInterfaces -> { + model.release(); + listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()); + }, + e -> { + model.release(); + listener.onFailure(e); + } )); }, listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java index 5322484db60..651c91b9579 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java @@ -73,10 +73,11 @@ public class InferenceRunner { LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId); try { PlainActionFuture localModelPlainActionFuture = new PlainActionFuture<>(); - modelLoadingService.getModelForSearch(modelId, localModelPlainActionFuture); - LocalModel localModel = localModelPlainActionFuture.actionGet(); + modelLoadingService.getModelForPipeline(modelId, localModelPlainActionFuture); TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config); - inferTestDocs(localModel, testDocsIterator); + try (LocalModel localModel = localModelPlainActionFuture.actionGet()) { + inferTestDocs(localModel, testDocsIterator); + } } catch (Exception e) { throw ExceptionsHelper.serverError("[{}] failed running inference on model [{}]", e, config.getId(), modelId); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregator.java index 8a0eb04df34..1ee1e4d6e48 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregator.java @@ -49,69 +49,76 @@ public class InferencePipelineAggregator extends PipelineAggregator { @Override public InternalAggregation reduce(InternalAggregation aggregation, InternalAggregation.ReduceContext reduceContext) { - InternalMultiBucketAggregation originalAgg = - 
(InternalMultiBucketAggregation) aggregation; - List buckets = originalAgg.getBuckets(); + try { + InternalMultiBucketAggregation originalAgg = + (InternalMultiBucketAggregation) aggregation; + List buckets = originalAgg.getBuckets(); - List newBuckets = new ArrayList<>(); - for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) { - Map inputFields = new HashMap<>(); + List newBuckets = new ArrayList<>(); + for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) { + Map inputFields = new HashMap<>(); - if (bucket.getDocCount() == 0) { - // ignore this empty bucket unless the doc count is used - if (bucketPathMap.containsKey("_count") == false) { - newBuckets.add(bucket); - continue; - } - } - - for (Map.Entry entry : bucketPathMap.entrySet()) { - String aggName = entry.getKey(); - String bucketPath = entry.getValue(); - Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath); - - if (propertyValue instanceof Number) { - double doubleVal = ((Number) propertyValue).doubleValue(); - // NaN or infinite values indicate a missing value or a - // valid result of an invalid calculation. 
Either way only - // a valid number will do - if (Double.isFinite(doubleVal)) { - inputFields.put(aggName, doubleVal); + if (bucket.getDocCount() == 0) { + // ignore this empty bucket unless the doc count is used + if (bucketPathMap.containsKey("_count") == false) { + newBuckets.add(bucket); + continue; } - } else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) { - double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value(); - if (Double.isFinite(doubleVal)) { - inputFields.put(aggName, doubleVal); - } - } else if (propertyValue instanceof StringTerms.Bucket) { - StringTerms.Bucket b = (StringTerms.Bucket) propertyValue; - inputFields.put(aggName, b.getKeyAsString()); - } else if (propertyValue instanceof String) { - inputFields.put(aggName, propertyValue); - } else if (propertyValue != null) { - // Doubles, String terms or null are valid, any other type is an error - throw invalidAggTypeError(bucketPath, propertyValue); } + + for (Map.Entry entry : bucketPathMap.entrySet()) { + String aggName = entry.getKey(); + String bucketPath = entry.getValue(); + Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath); + + if (propertyValue instanceof Number) { + double doubleVal = ((Number) propertyValue).doubleValue(); + // NaN or infinite values indicate a missing value or a + // valid result of an invalid calculation. 
Either way only + // a valid number will do + if (Double.isFinite(doubleVal)) { + inputFields.put(aggName, doubleVal); + } + } else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) { + double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value(); + if (Double.isFinite(doubleVal)) { + inputFields.put(aggName, doubleVal); + } + } else if (propertyValue instanceof StringTerms.Bucket) { + StringTerms.Bucket b = (StringTerms.Bucket) propertyValue; + inputFields.put(aggName, b.getKeyAsString()); + } else if (propertyValue instanceof String) { + inputFields.put(aggName, propertyValue); + } else if (propertyValue != null) { + // Doubles, String terms or null are valid, any other type is an error + throw invalidAggTypeError(bucketPath, propertyValue); + } + } + + + InferenceResults inference; + try { + inference = model.infer(inputFields, configUpdate); + } catch (Exception e) { + inference = new WarningInferenceResults(e.getMessage()); + } + + final List aggs = bucket.getAggregations().asList().stream().map( + (p) -> (InternalAggregation) p).collect(Collectors.toList()); + + InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference); + aggs.add(aggResult); + InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket); + newBuckets.add(newBucket); } + // the model is released at the end of this block. 
+ assert model.getReferenceCount() > 0; - InferenceResults inference; - try { - inference = model.infer(inputFields, configUpdate); - } catch (Exception e) { - inference = new WarningInferenceResults(e.getMessage()); - } - - final List aggs = bucket.getAggregations().asList().stream().map( - (p) -> (InternalAggregation) p).collect(Collectors.toList()); - - InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference); - aggs.add(aggResult); - InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket); - newBuckets.add(newBucket); + return originalAgg.create(newBuckets); + } finally { + model.release(); } - - return originalAgg.create(newBuckets); } public static Object resolveBucketValue(MultiBucketsAggregation agg, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index bf87427efee..ce4591a7abf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; import org.elasticsearch.license.License; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; @@ -19,16 +20,29 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MapHelper; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; +import java.io.Closeable; import 
java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING; -public class LocalModel { +/** + * LocalModels implement reference counting for proper accounting in + * the {@link CircuitBreaker}. When the model is no longer used {@link #release()} + * must be called and if the reference count == 0 then the model's bytes + * will be removed from the circuit breaker. + * + * The class is constructed with an initial reference count of 1 and its + * bytes must have been added to the circuit breaker before construction. + * New references must call {@link #acquire()} and {@link #release()} as the model + * is used. + */ +public class LocalModel implements Closeable { private final InferenceDefinition trainedModelDefinition; private final String modelId; @@ -40,15 +54,18 @@ public class LocalModel { private final LongAdder currentInferenceCount; private final InferenceConfig inferenceConfig; private final License.OperationMode licenseLevel; + private final CircuitBreaker trainedModelCircuitBreaker; + private final AtomicLong referenceCount; - public LocalModel(String modelId, + LocalModel(String modelId, String nodeId, InferenceDefinition trainedModelDefinition, TrainedModelInput input, Map defaultFieldMap, InferenceConfig modelInferenceConfig, License.OperationMode licenseLevel, - TrainedModelStatsService trainedModelStatsService) { + TrainedModelStatsService trainedModelStatsService, + CircuitBreaker trainedModelCircuitBreaker) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; this.fieldNames = new HashSet<>(input.getFieldNames()); @@ -60,6 +77,8 @@ public class LocalModel { this.currentInferenceCount = new LongAdder(); this.inferenceConfig = modelInferenceConfig; 
this.licenseLevel = licenseLevel; + this.trainedModelCircuitBreaker = trainedModelCircuitBreaker; + this.referenceCount = new AtomicLong(1); } long ramBytesUsed() { @@ -177,4 +196,36 @@ public class LocalModel { }); } } + + long acquire() { + long count = referenceCount.incrementAndGet(); + // protect against a race where the model could be released to a + // count of zero then the model is quickly re-acquired + if (count == 1) { + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelDefinition.ramBytesUsed(), modelId); + } + return count; + } + + public long getReferenceCount() { + return referenceCount.get(); + } + + public long release() { + long count = referenceCount.decrementAndGet(); + assert count >= 0; + if (count == 0) { + // no references to this model, it no longer needs to be accounted for + trainedModelCircuitBreaker.addWithoutBreaking(-ramBytesUsed()); + } + return referenceCount.get(); + } + + /** + * Convenience method so the class can be used in try-with-resource + * constructs to invoke {@link #release()}. + */ + public void close() { + release(); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 672fd67d2ce..89b99ab9830 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -60,6 +60,12 @@ import java.util.concurrent.TimeUnit; * mode the model is not cached. All other uses will cache the model * * If more than one processor references the same model, that model will only be cached once. + * + * LocalModels are created with a reference count of 1 accounting for the reference the + * cache holds. 
When models are evicted from the cache the reference count is decremented. + * The {@code getModelForX} methods automatically increment the model's reference count; + * it is up to the consumer to call {@link LocalModel#release()} when the model is no + * longer used. */ public class ModelLoadingService implements ClusterStateListener { @@ -93,8 +99,8 @@ public class ModelLoadingService implements ClusterStateListener { } private static class ModelAndConsumer { - private LocalModel model; - private EnumSet consumers; + private final LocalModel model; + private final EnumSet consumers; private ModelAndConsumer(LocalModel model, Consumer consumer) { this.model = model; @@ -197,6 +203,12 @@ public class ModelLoadingService implements ClusterStateListener { ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { cachedModel.consumers.add(consumer); + try { + cachedModel.model.acquire(); + } catch (CircuitBreakingException e) { + modelActionListener.onFailure(e); + return; + } modelActionListener.onResponse(cachedModel.model); logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId)); return; @@ -223,6 +235,12 @@ public class ModelLoadingService implements ClusterStateListener { ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { cachedModel.consumers.add(consumer); + try { + cachedModel.model.acquire(); + } catch (CircuitBreakingException e) { + modelActionListener.onFailure(e); + return true; + } modelActionListener.onResponse(cachedModel.model); return true; } @@ -256,20 +274,16 @@ public class ModelLoadingService implements ClusterStateListener { trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); provider.getTrainedModelForInference(modelId, ActionListener.wrap( inferenceDefinition -> { - // Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory - // So that the memory 
this model uses in the circuit breaker is the most accurate estimate. - long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory(); - if (estimateDiff < 0) { - trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff); - } else if (estimateDiff > 0) { // rare case where estimate is now HIGHER - try { - trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId); - } catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well - trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); - handleLoadFailure(modelId, ex); - return; - } + try { + // Since we have used the previously stored estimate to help guard against OOM we need + // to adjust the memory so that the memory this model uses in the circuit breaker + // is the most accurate estimate. + updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig); + } catch (CircuitBreakingException ex) { + handleLoadFailure(modelId, ex); + return; } + handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition); }, failure -> { @@ -300,8 +314,14 @@ public class ModelLoadingService implements ClusterStateListener { InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? 
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : trainedModelConfig.getInferenceConfig(); - // Remove the bytes as we cannot control how long the caller will keep the model in memory - trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + + try { + updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig); + } catch (CircuitBreakingException ex) { + modelActionListener.onFailure(ex); + return; + } + modelActionListener.onResponse(new LocalModel( trainedModelConfig.getModelId(), localNode, @@ -310,7 +330,8 @@ public class ModelLoadingService implements ClusterStateListener { trainedModelConfig.getDefaultFieldMap(), inferenceConfig, trainedModelConfig.getLicenseLevel(), - modelStatsService)); + modelStatsService, + trainedModelCircuitBreaker)); }, // Failure getting the definition, remove the initial estimation value e -> { @@ -323,6 +344,21 @@ public class ModelLoadingService implements ClusterStateListener { )); } + private void updateCircuitBreakerEstimate(String modelId, InferenceDefinition inferenceDefinition, + TrainedModelConfig trainedModelConfig) throws CircuitBreakingException { + long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory(); + if (estimateDiff < 0) { + trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff); + } else if (estimateDiff > 0) { // rare case where estimate is now HIGHER + try { + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId); + } catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + throw ex; + } + } + } + private void handleLoadSuccess(String modelId, Consumer consumer, TrainedModelConfig trainedModelConfig, @@ -339,21 +375,30 @@ public class ModelLoadingService implements ClusterStateListener { 
trainedModelConfig.getDefaultFieldMap(), inferenceConfig, trainedModelConfig.getLicenseLevel(), - modelStatsService); + modelStatsService, + trainedModelCircuitBreaker); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such // Consequently, we should not store the retrieved model if (listeners == null) { - trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed()); + loadedModel.release(); return; } + + // temporarily increase the reference count before adding to + // the cache in case the model is evicted before the listeners + // are called in which case acquire() would throw. + loadedModel.acquire(); localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer)); shouldNotAudit.remove(modelId); } // synchronized (loadingListeners) for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + loadedModel.acquire(); listener.onResponse(loadedModel); } + // account for the acquire in the synchronized block above + loadedModel.release(); } private void handleLoadFailure(String modelId, Exception failure) { @@ -388,7 +433,7 @@ public class ModelLoadingService implements ClusterStateListener { // If the model is no longer referenced, flush the stats to persist as soon as possible notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false); } finally { - trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().model.ramBytesUsed()); + notification.getValue().model.release(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 10eb7e7161a..8f0e2d0b06f 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.license.License; import org.elasticsearch.test.ESTestCase; @@ -51,9 +52,11 @@ import static org.hamcrest.Matchers.hasSize; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.argThat; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.internal.verification.VerificationModeFactory.times; public class LocalModelTests extends ESTestCase { @@ -75,7 +78,8 @@ public class LocalModelTests extends ESTestCase { Collections.singletonMap("field.foo", "field.foo.keyword"), ClassificationConfig.EMPTY_PARAMS, randomFrom(License.OperationMode.values()), - modelStatsService); + modelStatsService, + mock(CircuitBreaker.class)); Map fields = new HashMap() {{ put("field.foo", 1.0); put("field", Collections.singletonMap("bar", 0.5)); @@ -105,7 +109,8 @@ public class LocalModelTests extends ESTestCase { Collections.singletonMap("field.foo", "field.foo.keyword"), ClassificationConfig.EMPTY_PARAMS, License.OperationMode.PLATINUM, - modelStatsService); + modelStatsService, + mock(CircuitBreaker.class)); result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); @@ -148,7 +153,8 @@ public class LocalModelTests extends ESTestCase { 
Collections.singletonMap("field.foo", "field.foo.keyword"), ClassificationConfig.EMPTY_PARAMS, License.OperationMode.PLATINUM, - modelStatsService); + modelStatsService, + mock(CircuitBreaker.class)); Map fields = new HashMap() {{ put("field.foo", 1.0); put("field.bar", 0.5); @@ -204,7 +210,8 @@ public class LocalModelTests extends ESTestCase { Collections.singletonMap("bar", "bar.keyword"), RegressionConfig.EMPTY_PARAMS, License.OperationMode.PLATINUM, - modelStatsService); + modelStatsService, + mock(CircuitBreaker.class)); Map fields = new HashMap() {{ put("foo", 1.0); @@ -232,7 +239,8 @@ public class LocalModelTests extends ESTestCase { null, RegressionConfig.EMPTY_PARAMS, License.OperationMode.PLATINUM, - modelStatsService); + modelStatsService, + mock(CircuitBreaker.class)); Map fields = new HashMap() {{ put("something", 1.0); @@ -263,7 +271,8 @@ public class LocalModelTests extends ESTestCase { null, ClassificationConfig.EMPTY_PARAMS, License.OperationMode.PLATINUM, - modelStatsService + modelStatsService, + mock(CircuitBreaker.class) ); Map fields = new HashMap() {{ put("field.foo", 1.0); @@ -309,6 +318,63 @@ public class LocalModelTests extends ESTestCase { assertThat(fields, equalTo(expectedMap)); } + public void testReferenceCounting() throws IOException { + TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); + String modelId = "ref-count-model"; + List inputFields = Arrays.asList("field.foo", "field.bar"); + InferenceDefinition definition = InferenceDefinition.builder() + .setTrainedModel(buildClassificationInference(false)) + .build(); + + { + CircuitBreaker breaker = mock(CircuitBreaker.class); + LocalModel model = new LocalModel(modelId, + "test-node", + definition, + new TrainedModelInput(inputFields), + null, + ClassificationConfig.EMPTY_PARAMS, + License.OperationMode.PLATINUM, + modelStatsService, + breaker + ); + + model.release(); + verify(breaker, times(1)).addWithoutBreaking(eq(-definition.ramBytesUsed())); + 
verifyNoMoreInteractions(breaker); + assertEquals(0L, model.getReferenceCount()); + + // reacquire + model.acquire(); + verify(breaker, times(1)).addEstimateBytesAndMaybeBreak(eq(definition.ramBytesUsed()), eq(modelId)); + verifyNoMoreInteractions(breaker); + assertEquals(1L, model.getReferenceCount()); + } + + { + CircuitBreaker breaker = mock(CircuitBreaker.class); + LocalModel model = new LocalModel(modelId, + "test-node", + definition, + new TrainedModelInput(inputFields), + null, + ClassificationConfig.EMPTY_PARAMS, + License.OperationMode.PLATINUM, + modelStatsService, + breaker + ); + + model.acquire(); + model.acquire(); + model.release(); + model.release(); + model.release(); + verify(breaker, times(1)).addWithoutBreaking(eq(-definition.ramBytesUsed())); + verifyNoMoreInteractions(breaker); + assertEquals(0L, model.getReferenceCount()); + } + } + private static SingleValueInferenceResults getSingleValue(LocalModel model, Map fields, InferenceConfigUpdate config) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 8513d8dccfb..0a7e2475ebd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -58,6 +58,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; @@ -451,6 +452,89 @@ public class ModelLoadingServiceTests extends ESTestCase { }); } + public void testReferenceCounting() throws Exception { + String modelId = "test-reference-counting"; + 
withTrainedModel(modelId, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + modelLoadingService.clusterChanged(ingestChangedEvent(modelId)); + + PlainActionFuture forPipeline = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(modelId, forPipeline); + final LocalModel model = forPipeline.get(); + assertBusy(() -> assertEquals(2, model.getReferenceCount())); + + PlainActionFuture forSearch = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(modelId, forSearch); + forSearch.get(); + assertBusy(() -> assertEquals(3, model.getReferenceCount())); + + model.release(); + assertBusy(() -> assertEquals(2, model.getReferenceCount())); + + PlainActionFuture forSearch2 = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(modelId, forSearch2); + forSearch2.get(); + assertBusy(() -> assertEquals(3, model.getReferenceCount())); + } + + public void testReferenceCountingForPipeline() throws Exception { + String modelId = "test-reference-counting-for-pipeline"; + withTrainedModel(modelId, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + modelLoadingService.clusterChanged(ingestChangedEvent(modelId)); + + PlainActionFuture forPipeline = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(modelId, forPipeline); + final LocalModel model = forPipeline.get(); + assertBusy(() -> assertEquals(2, model.getReferenceCount())); + + PlainActionFuture forPipeline2 = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(modelId, forPipeline2); + forPipeline2.get(); + assertBusy(() -> assertEquals(3, model.getReferenceCount())); + + // will cause the model to be 
evicted + modelLoadingService.clusterChanged(ingestChangedEvent()); + assertBusy(() -> assertEquals(2, model.getReferenceCount())); + } + + public void testReferenceCounting_ModelIsNotCached() throws ExecutionException, InterruptedException { + String modelId = "test-reference-counting-not-cached"; + withTrainedModel(modelId, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(modelId, future); + LocalModel model = future.get(); + assertEquals(1, model.getReferenceCount()); + } + @SuppressWarnings("unchecked") private void withTrainedModel(String modelId, long size) { InferenceDefinition definition = mock(InferenceDefinition.class);