Accounting for model size when models are not cached (#59607)
When an inference model is loaded, its size is accounted for in the circuit breaker, and that accounting should not be released until there are no users of the model. Adds a reference count to the model to track usage.
This commit is contained in:
parent
ef9b14b07e
commit
df7fc8f967
|
@ -65,9 +65,14 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|||
model.infer(stringObjectMap, request.getUpdate(), chainedTask)));
|
||||
|
||||
typedChainTaskExecutor.execute(ActionListener.wrap(
|
||||
inferenceResultsInterfaces ->
|
||||
listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()),
|
||||
listener::onFailure
|
||||
inferenceResultsInterfaces -> {
|
||||
model.release();
|
||||
listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build());
|
||||
},
|
||||
e -> {
|
||||
model.release();
|
||||
listener.onFailure(e);
|
||||
}
|
||||
));
|
||||
},
|
||||
listener::onFailure
|
||||
|
|
|
@ -73,10 +73,11 @@ public class InferenceRunner {
|
|||
LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId);
|
||||
try {
|
||||
PlainActionFuture<LocalModel> localModelPlainActionFuture = new PlainActionFuture<>();
|
||||
modelLoadingService.getModelForSearch(modelId, localModelPlainActionFuture);
|
||||
LocalModel localModel = localModelPlainActionFuture.actionGet();
|
||||
modelLoadingService.getModelForPipeline(modelId, localModelPlainActionFuture);
|
||||
TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config);
|
||||
inferTestDocs(localModel, testDocsIterator);
|
||||
try (LocalModel localModel = localModelPlainActionFuture.actionGet()) {
|
||||
inferTestDocs(localModel, testDocsIterator);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw ExceptionsHelper.serverError("[{}] failed running inference on model [{}]", e, config.getId(), modelId);
|
||||
}
|
||||
|
|
|
@ -49,69 +49,76 @@ public class InferencePipelineAggregator extends PipelineAggregator {
|
|||
@Override
|
||||
public InternalAggregation reduce(InternalAggregation aggregation, InternalAggregation.ReduceContext reduceContext) {
|
||||
|
||||
InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
|
||||
(InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
|
||||
List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
|
||||
try {
|
||||
InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
|
||||
(InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
|
||||
List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
|
||||
|
||||
List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
|
||||
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
|
||||
Map<String, Object> inputFields = new HashMap<>();
|
||||
List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
|
||||
for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
|
||||
Map<String, Object> inputFields = new HashMap<>();
|
||||
|
||||
if (bucket.getDocCount() == 0) {
|
||||
// ignore this empty bucket unless the doc count is used
|
||||
if (bucketPathMap.containsKey("_count") == false) {
|
||||
newBuckets.add(bucket);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
|
||||
String aggName = entry.getKey();
|
||||
String bucketPath = entry.getValue();
|
||||
Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
|
||||
|
||||
if (propertyValue instanceof Number) {
|
||||
double doubleVal = ((Number) propertyValue).doubleValue();
|
||||
// NaN or infinite values indicate a missing value or a
|
||||
// valid result of an invalid calculation. Either way only
|
||||
// a valid number will do
|
||||
if (Double.isFinite(doubleVal)) {
|
||||
inputFields.put(aggName, doubleVal);
|
||||
if (bucket.getDocCount() == 0) {
|
||||
// ignore this empty bucket unless the doc count is used
|
||||
if (bucketPathMap.containsKey("_count") == false) {
|
||||
newBuckets.add(bucket);
|
||||
continue;
|
||||
}
|
||||
} else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
|
||||
double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
|
||||
if (Double.isFinite(doubleVal)) {
|
||||
inputFields.put(aggName, doubleVal);
|
||||
}
|
||||
} else if (propertyValue instanceof StringTerms.Bucket) {
|
||||
StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
|
||||
inputFields.put(aggName, b.getKeyAsString());
|
||||
} else if (propertyValue instanceof String) {
|
||||
inputFields.put(aggName, propertyValue);
|
||||
} else if (propertyValue != null) {
|
||||
// Doubles, String terms or null are valid, any other type is an error
|
||||
throw invalidAggTypeError(bucketPath, propertyValue);
|
||||
}
|
||||
|
||||
for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
|
||||
String aggName = entry.getKey();
|
||||
String bucketPath = entry.getValue();
|
||||
Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
|
||||
|
||||
if (propertyValue instanceof Number) {
|
||||
double doubleVal = ((Number) propertyValue).doubleValue();
|
||||
// NaN or infinite values indicate a missing value or a
|
||||
// valid result of an invalid calculation. Either way only
|
||||
// a valid number will do
|
||||
if (Double.isFinite(doubleVal)) {
|
||||
inputFields.put(aggName, doubleVal);
|
||||
}
|
||||
} else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
|
||||
double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
|
||||
if (Double.isFinite(doubleVal)) {
|
||||
inputFields.put(aggName, doubleVal);
|
||||
}
|
||||
} else if (propertyValue instanceof StringTerms.Bucket) {
|
||||
StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
|
||||
inputFields.put(aggName, b.getKeyAsString());
|
||||
} else if (propertyValue instanceof String) {
|
||||
inputFields.put(aggName, propertyValue);
|
||||
} else if (propertyValue != null) {
|
||||
// Doubles, String terms or null are valid, any other type is an error
|
||||
throw invalidAggTypeError(bucketPath, propertyValue);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
InferenceResults inference;
|
||||
try {
|
||||
inference = model.infer(inputFields, configUpdate);
|
||||
} catch (Exception e) {
|
||||
inference = new WarningInferenceResults(e.getMessage());
|
||||
}
|
||||
|
||||
final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
|
||||
(p) -> (InternalAggregation) p).collect(Collectors.toList());
|
||||
|
||||
InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
|
||||
aggs.add(aggResult);
|
||||
InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
|
||||
newBuckets.add(newBucket);
|
||||
}
|
||||
|
||||
// the model is released at the end of this block.
|
||||
assert model.getReferenceCount() > 0;
|
||||
|
||||
InferenceResults inference;
|
||||
try {
|
||||
inference = model.infer(inputFields, configUpdate);
|
||||
} catch (Exception e) {
|
||||
inference = new WarningInferenceResults(e.getMessage());
|
||||
}
|
||||
|
||||
final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
|
||||
(p) -> (InternalAggregation) p).collect(Collectors.toList());
|
||||
|
||||
InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
|
||||
aggs.add(aggResult);
|
||||
InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
|
||||
newBuckets.add(newBucket);
|
||||
return originalAgg.create(newBuckets);
|
||||
} finally {
|
||||
model.release();
|
||||
}
|
||||
|
||||
return originalAgg.create(newBuckets);
|
||||
}
|
||||
|
||||
public static Object resolveBucketValue(MultiBucketsAggregation agg,
|
||||
|
|
|
@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
|
|||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
|
||||
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
|
||||
|
@ -19,16 +20,29 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|||
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
|
||||
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
|
||||
|
||||
public class LocalModel {
|
||||
/**
|
||||
* LocalModels implement reference counting for proper accounting in
|
||||
* the {@link CircuitBreaker}. When the model is no longer used {@link #release()}
|
||||
* must be called and if the reference count == 0 then the model's bytes
|
||||
* will be removed from the circuit breaker.
|
||||
*
|
||||
* The class is constructed with an initial reference count of 1 and its
|
||||
* bytes <em>must</em> have been added to the circuit breaker before construction.
|
||||
* New references must call {@link #acquire()} and {@link #release()} as the model
|
||||
* is used.
|
||||
*/
|
||||
public class LocalModel implements Closeable {
|
||||
|
||||
private final InferenceDefinition trainedModelDefinition;
|
||||
private final String modelId;
|
||||
|
@ -40,15 +54,18 @@ public class LocalModel {
|
|||
private final LongAdder currentInferenceCount;
|
||||
private final InferenceConfig inferenceConfig;
|
||||
private final License.OperationMode licenseLevel;
|
||||
private final CircuitBreaker trainedModelCircuitBreaker;
|
||||
private final AtomicLong referenceCount;
|
||||
|
||||
public LocalModel(String modelId,
|
||||
LocalModel(String modelId,
|
||||
String nodeId,
|
||||
InferenceDefinition trainedModelDefinition,
|
||||
TrainedModelInput input,
|
||||
Map<String, String> defaultFieldMap,
|
||||
InferenceConfig modelInferenceConfig,
|
||||
License.OperationMode licenseLevel,
|
||||
TrainedModelStatsService trainedModelStatsService) {
|
||||
TrainedModelStatsService trainedModelStatsService,
|
||||
CircuitBreaker trainedModelCircuitBreaker) {
|
||||
this.trainedModelDefinition = trainedModelDefinition;
|
||||
this.modelId = modelId;
|
||||
this.fieldNames = new HashSet<>(input.getFieldNames());
|
||||
|
@ -60,6 +77,8 @@ public class LocalModel {
|
|||
this.currentInferenceCount = new LongAdder();
|
||||
this.inferenceConfig = modelInferenceConfig;
|
||||
this.licenseLevel = licenseLevel;
|
||||
this.trainedModelCircuitBreaker = trainedModelCircuitBreaker;
|
||||
this.referenceCount = new AtomicLong(1);
|
||||
}
|
||||
|
||||
long ramBytesUsed() {
|
||||
|
@ -177,4 +196,36 @@ public class LocalModel {
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
long acquire() {
|
||||
long count = referenceCount.incrementAndGet();
|
||||
// protect against a race where the model could be released to a
|
||||
// count of zero then the model is quickly re-acquired
|
||||
if (count == 1) {
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelDefinition.ramBytesUsed(), modelId);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
public long getReferenceCount() {
|
||||
return referenceCount.get();
|
||||
}
|
||||
|
||||
public long release() {
|
||||
long count = referenceCount.decrementAndGet();
|
||||
assert count >= 0;
|
||||
if (count == 0) {
|
||||
// no references to this model, it no longer needs to be accounted for
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-ramBytesUsed());
|
||||
}
|
||||
return referenceCount.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience method so the class can be used in try-with-resource
|
||||
* constructs to invoke {@link #release()}.
|
||||
*/
|
||||
public void close() {
|
||||
release();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,6 +60,12 @@ import java.util.concurrent.TimeUnit;
|
|||
* mode the model is not cached. All other uses will cache the model
|
||||
*
|
||||
* If more than one processor references the same model, that model will only be cached once.
|
||||
*
|
||||
* LocalModels are created with a reference count of 1 accounting for the reference the
|
||||
* cache holds. When models are evicted from the cache the reference count is decremented.
|
||||
* The {@code getModelForX} methods automatically increment the model's reference count;
|
||||
* it is up to the consumer to call {@link LocalModel#release()} when the model is no
|
||||
* longer used.
|
||||
*/
|
||||
public class ModelLoadingService implements ClusterStateListener {
|
||||
|
||||
|
@ -93,8 +99,8 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
}
|
||||
|
||||
private static class ModelAndConsumer {
|
||||
private LocalModel model;
|
||||
private EnumSet<Consumer> consumers;
|
||||
private final LocalModel model;
|
||||
private final EnumSet<Consumer> consumers;
|
||||
|
||||
private ModelAndConsumer(LocalModel model, Consumer consumer) {
|
||||
this.model = model;
|
||||
|
@ -197,6 +203,12 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
ModelAndConsumer cachedModel = localModelCache.get(modelId);
|
||||
if (cachedModel != null) {
|
||||
cachedModel.consumers.add(consumer);
|
||||
try {
|
||||
cachedModel.model.acquire();
|
||||
} catch (CircuitBreakingException e) {
|
||||
modelActionListener.onFailure(e);
|
||||
return;
|
||||
}
|
||||
modelActionListener.onResponse(cachedModel.model);
|
||||
logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
|
||||
return;
|
||||
|
@ -223,6 +235,12 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
ModelAndConsumer cachedModel = localModelCache.get(modelId);
|
||||
if (cachedModel != null) {
|
||||
cachedModel.consumers.add(consumer);
|
||||
try {
|
||||
cachedModel.model.acquire();
|
||||
} catch (CircuitBreakingException e) {
|
||||
modelActionListener.onFailure(e);
|
||||
return true;
|
||||
}
|
||||
modelActionListener.onResponse(cachedModel.model);
|
||||
return true;
|
||||
}
|
||||
|
@ -256,20 +274,16 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
|
||||
inferenceDefinition -> {
|
||||
// Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory
|
||||
// So that the memory this model uses in the circuit breaker is the most accurate estimate.
|
||||
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
|
||||
if (estimateDiff < 0) {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
|
||||
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
|
||||
try {
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
|
||||
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
|
||||
handleLoadFailure(modelId, ex);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
// Since we have used the previously stored estimate to help guard against OOM we need
|
||||
// to adjust the memory so that the memory this model uses in the circuit breaker
|
||||
// is the most accurate estimate.
|
||||
updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
|
||||
} catch (CircuitBreakingException ex) {
|
||||
handleLoadFailure(modelId, ex);
|
||||
return;
|
||||
}
|
||||
|
||||
handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
|
||||
},
|
||||
failure -> {
|
||||
|
@ -300,8 +314,14 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
|
||||
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
|
||||
trainedModelConfig.getInferenceConfig();
|
||||
// Remove the bytes as we cannot control how long the caller will keep the model in memory
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
|
||||
|
||||
try {
|
||||
updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
|
||||
} catch (CircuitBreakingException ex) {
|
||||
modelActionListener.onFailure(ex);
|
||||
return;
|
||||
}
|
||||
|
||||
modelActionListener.onResponse(new LocalModel(
|
||||
trainedModelConfig.getModelId(),
|
||||
localNode,
|
||||
|
@ -310,7 +330,8 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
trainedModelConfig.getDefaultFieldMap(),
|
||||
inferenceConfig,
|
||||
trainedModelConfig.getLicenseLevel(),
|
||||
modelStatsService));
|
||||
modelStatsService,
|
||||
trainedModelCircuitBreaker));
|
||||
},
|
||||
// Failure getting the definition, remove the initial estimation value
|
||||
e -> {
|
||||
|
@ -323,6 +344,21 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
));
|
||||
}
|
||||
|
||||
private void updateCircuitBreakerEstimate(String modelId, InferenceDefinition inferenceDefinition,
|
||||
TrainedModelConfig trainedModelConfig) throws CircuitBreakingException {
|
||||
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
|
||||
if (estimateDiff < 0) {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
|
||||
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
|
||||
try {
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
|
||||
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void handleLoadSuccess(String modelId,
|
||||
Consumer consumer,
|
||||
TrainedModelConfig trainedModelConfig,
|
||||
|
@ -339,21 +375,30 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
trainedModelConfig.getDefaultFieldMap(),
|
||||
inferenceConfig,
|
||||
trainedModelConfig.getLicenseLevel(),
|
||||
modelStatsService);
|
||||
modelStatsService,
|
||||
trainedModelCircuitBreaker);
|
||||
synchronized (loadingListeners) {
|
||||
listeners = loadingListeners.remove(modelId);
|
||||
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
|
||||
// Consequently, we should not store the retrieved model
|
||||
if (listeners == null) {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed());
|
||||
loadedModel.release();
|
||||
return;
|
||||
}
|
||||
|
||||
// temporarily increase the reference count before adding to
|
||||
// the cache in case the model is evicted before the listeners
|
||||
// are called in which case acquire() would throw.
|
||||
loadedModel.acquire();
|
||||
localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
|
||||
shouldNotAudit.remove(modelId);
|
||||
} // synchronized (loadingListeners)
|
||||
for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
|
||||
loadedModel.acquire();
|
||||
listener.onResponse(loadedModel);
|
||||
}
|
||||
// account for the acquire in the synchronized block above
|
||||
loadedModel.release();
|
||||
}
|
||||
|
||||
private void handleLoadFailure(String modelId, Exception failure) {
|
||||
|
@ -388,7 +433,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
// If the model is no longer referenced, flush the stats to persist as soon as possible
|
||||
notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
|
||||
} finally {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().model.ramBytesUsed());
|
||||
notification.getValue().model.release();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
package org.elasticsearch.xpack.ml.inference.loadingservice;
|
||||
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.ingest.IngestDocument;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
@ -51,9 +52,11 @@ import static org.hamcrest.Matchers.hasSize;
|
|||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyBoolean;
|
||||
import static org.mockito.Matchers.argThat;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.internal.verification.VerificationModeFactory.times;
|
||||
|
||||
public class LocalModelTests extends ESTestCase {
|
||||
|
@ -75,7 +78,8 @@ public class LocalModelTests extends ESTestCase {
|
|||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
randomFrom(License.OperationMode.values()),
|
||||
modelStatsService);
|
||||
modelStatsService,
|
||||
mock(CircuitBreaker.class));
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("field.foo", 1.0);
|
||||
put("field", Collections.singletonMap("bar", 0.5));
|
||||
|
@ -105,7 +109,8 @@ public class LocalModelTests extends ESTestCase {
|
|||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
modelStatsService,
|
||||
mock(CircuitBreaker.class));
|
||||
result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
|
||||
assertThat(result.value(), equalTo(0.0));
|
||||
assertThat(result.valueAsString(), equalTo("not_to_be"));
|
||||
|
@ -148,7 +153,8 @@ public class LocalModelTests extends ESTestCase {
|
|||
Collections.singletonMap("field.foo", "field.foo.keyword"),
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
modelStatsService,
|
||||
mock(CircuitBreaker.class));
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("field.foo", 1.0);
|
||||
put("field.bar", 0.5);
|
||||
|
@ -204,7 +210,8 @@ public class LocalModelTests extends ESTestCase {
|
|||
Collections.singletonMap("bar", "bar.keyword"),
|
||||
RegressionConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
modelStatsService,
|
||||
mock(CircuitBreaker.class));
|
||||
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("foo", 1.0);
|
||||
|
@ -232,7 +239,8 @@ public class LocalModelTests extends ESTestCase {
|
|||
null,
|
||||
RegressionConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService);
|
||||
modelStatsService,
|
||||
mock(CircuitBreaker.class));
|
||||
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("something", 1.0);
|
||||
|
@ -263,7 +271,8 @@ public class LocalModelTests extends ESTestCase {
|
|||
null,
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService
|
||||
modelStatsService,
|
||||
mock(CircuitBreaker.class)
|
||||
);
|
||||
Map<String, Object> fields = new HashMap<String, Object>() {{
|
||||
put("field.foo", 1.0);
|
||||
|
@ -309,6 +318,63 @@ public class LocalModelTests extends ESTestCase {
|
|||
assertThat(fields, equalTo(expectedMap));
|
||||
}
|
||||
|
||||
public void testReferenceCounting() throws IOException {
|
||||
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
|
||||
String modelId = "ref-count-model";
|
||||
List<String> inputFields = Arrays.asList("field.foo", "field.bar");
|
||||
InferenceDefinition definition = InferenceDefinition.builder()
|
||||
.setTrainedModel(buildClassificationInference(false))
|
||||
.build();
|
||||
|
||||
{
|
||||
CircuitBreaker breaker = mock(CircuitBreaker.class);
|
||||
LocalModel model = new LocalModel(modelId,
|
||||
"test-node",
|
||||
definition,
|
||||
new TrainedModelInput(inputFields),
|
||||
null,
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService,
|
||||
breaker
|
||||
);
|
||||
|
||||
model.release();
|
||||
verify(breaker, times(1)).addWithoutBreaking(eq(-definition.ramBytesUsed()));
|
||||
verifyNoMoreInteractions(breaker);
|
||||
assertEquals(0L, model.getReferenceCount());
|
||||
|
||||
// reacquire
|
||||
model.acquire();
|
||||
verify(breaker, times(1)).addEstimateBytesAndMaybeBreak(eq(definition.ramBytesUsed()), eq(modelId));
|
||||
verifyNoMoreInteractions(breaker);
|
||||
assertEquals(1L, model.getReferenceCount());
|
||||
}
|
||||
|
||||
{
|
||||
CircuitBreaker breaker = mock(CircuitBreaker.class);
|
||||
LocalModel model = new LocalModel(modelId,
|
||||
"test-node",
|
||||
definition,
|
||||
new TrainedModelInput(inputFields),
|
||||
null,
|
||||
ClassificationConfig.EMPTY_PARAMS,
|
||||
License.OperationMode.PLATINUM,
|
||||
modelStatsService,
|
||||
breaker
|
||||
);
|
||||
|
||||
model.acquire();
|
||||
model.acquire();
|
||||
model.release();
|
||||
model.release();
|
||||
model.release();
|
||||
verify(breaker, times(1)).addWithoutBreaking(eq(-definition.ramBytesUsed()));
|
||||
verifyNoMoreInteractions(breaker);
|
||||
assertEquals(0L, model.getReferenceCount());
|
||||
}
|
||||
}
|
||||
|
||||
private static SingleValueInferenceResults getSingleValue(LocalModel model,
|
||||
Map<String, Object> fields,
|
||||
InferenceConfigUpdate config)
|
||||
|
|
|
@ -58,6 +58,7 @@ import java.util.HashMap;
|
|||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
|
||||
|
@ -451,6 +452,89 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
});
|
||||
}
|
||||
|
||||
public void testReferenceCounting() throws Exception {
|
||||
String modelId = "test-reference-counting";
|
||||
withTrainedModel(modelId, 1L);
|
||||
|
||||
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
||||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(modelId));
|
||||
|
||||
PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
|
||||
modelLoadingService.getModelForPipeline(modelId, forPipeline);
|
||||
final LocalModel model = forPipeline.get();
|
||||
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
||||
|
||||
PlainActionFuture<LocalModel> forSearch = new PlainActionFuture<>();
|
||||
modelLoadingService.getModelForPipeline(modelId, forSearch);
|
||||
forSearch.get();
|
||||
assertBusy(() -> assertEquals(3, model.getReferenceCount()));
|
||||
|
||||
model.release();
|
||||
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
||||
|
||||
PlainActionFuture<LocalModel> forSearch2 = new PlainActionFuture<>();
|
||||
modelLoadingService.getModelForPipeline(modelId, forSearch2);
|
||||
forSearch2.get();
|
||||
assertBusy(() -> assertEquals(3, model.getReferenceCount()));
|
||||
}
|
||||
|
||||
public void testReferenceCountingForPipeline() throws Exception {
|
||||
String modelId = "test-reference-counting-for-pipeline";
|
||||
withTrainedModel(modelId, 1L);
|
||||
|
||||
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
||||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(modelId));
|
||||
|
||||
PlainActionFuture<LocalModel> forPipeline = new PlainActionFuture<>();
|
||||
modelLoadingService.getModelForPipeline(modelId, forPipeline);
|
||||
final LocalModel model = forPipeline.get();
|
||||
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
||||
|
||||
PlainActionFuture<LocalModel> forPipeline2 = new PlainActionFuture<>();
|
||||
modelLoadingService.getModelForPipeline(modelId, forPipeline2);
|
||||
forPipeline2.get();
|
||||
assertBusy(() -> assertEquals(3, model.getReferenceCount()));
|
||||
|
||||
// will cause the model to be evicted
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent());
|
||||
assertBusy(() -> assertEquals(2, model.getReferenceCount()));
|
||||
}
|
||||
|
||||
public void testReferenceCounting_ModelIsNotCached() throws ExecutionException, InterruptedException {
|
||||
String modelId = "test-reference-counting-not-cached";
|
||||
withTrainedModel(modelId, 1L);
|
||||
|
||||
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
||||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
PlainActionFuture<LocalModel> future = new PlainActionFuture<>();
|
||||
modelLoadingService.getModelForPipeline(modelId, future);
|
||||
LocalModel model = future.get();
|
||||
assertEquals(1, model.getReferenceCount());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void withTrainedModel(String modelId, long size) {
|
||||
InferenceDefinition definition = mock(InferenceDefinition.class);
|
||||
|
|
Loading…
Reference in New Issue