From d5522c2747674f1079b42bdcfb627a0bc3a7008b Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 8 Jun 2020 16:02:48 -0400 Subject: [PATCH] [ML] add new circuit breaker for inference model caching (#57731) (#57830) This adds new plugin level circuit breaker for the ML plugin. `model_inference` is the circuit breaker qualified name. Right now it simply adds to the breaker when the model is loaded (and possibly breaking) and removing from the breaker when the model is unloaded. --- docs/reference/settings/ml-settings.asciidoc | 22 +++ .../xpack/ml/MachineLearning.java | 38 +++- .../loadingservice/ModelLoadingService.java | 123 ++++++++---- .../persistence/TrainedModelProvider.java | 92 ++++----- .../xpack/ml/LocalStateMachineLearning.java | 9 +- .../ModelLoadingServiceTests.java | 184 ++++++++++++++++-- .../xpack/ml/support/BaseMlIntegTestCase.java | 3 + 7 files changed, 367 insertions(+), 104 deletions(-) diff --git a/docs/reference/settings/ml-settings.asciidoc b/docs/reference/settings/ml-settings.asciidoc index 7cdeba856f2..73bfcd36463 100644 --- a/docs/reference/settings/ml-settings.asciidoc +++ b/docs/reference/settings/ml-settings.asciidoc @@ -56,6 +56,7 @@ The maximum inference cache size allowed. The inference cache exists in the JVM heap on each ingest node. The cache affords faster processing times for the `inference` processor. The value can be a static byte sized value (i.e. "2gb") or a percentage of total allocated heap. The default is "40%". +See also <>. `xpack.ml.inference_model.time_to_live`:: The time to live (TTL) for models in the inference model cache. The TTL is @@ -137,3 +138,24 @@ to the {es} JVM. When such processes are started they must connect to the {es} JVM. If such a process does not connect within the time period specified by this setting then the process is assumed to have failed. Defaults to `10s`. The minimum value for this setting is `5s`. + +[[model-inference-circuit-breaker]] +==== {ml-cap} circuit breaker settings + +`breaker.model_inference.limit` (<>) +Limit for model inference breaker, defaults to 50% of JVM heap. +If the parent circuit breaker is less than 50% of JVM heap, it is bound +to that limit instead. +See <>. + +`breaker.model_inference.overhead` (<>) +A constant that all accounting estimations are multiplied with to determine +a final estimation. Defaults to 1. +See <>. + +`breaker.model_inference.type` +The underlying type of the circuit breaker. There are two valid options: +`noop`, meaning the circuit breaker does nothing to prevent too much memory usage, +`memory`, meaning the circuit breaker tracks the memory used by inference models and +could potentially break and prevent OutOfMemory errors. +The default is `memory`. diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index bd037714991..1372db30e5b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Module; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.IndexScopedSettings; @@ -40,12 +41,15 @@ import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider; +import org.elasticsearch.indices.breaker.BreakerSettings; import org.elasticsearch.ingest.Processor; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.monitor.os.OsProbe; import org.elasticsearch.monitor.os.OsStats; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.plugins.CircuitBreakerPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; @@ -323,7 +327,11 @@ import java.util.function.UnaryOperator; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; -public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { +public class MachineLearning extends Plugin implements SystemIndexPlugin, + AnalysisPlugin, + CircuitBreakerPlugin, + IngestPlugin, + PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -331,6 +339,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys public static final String JOB_COMMS_THREAD_POOL_NAME = NAME + "_job_comms"; public static final String UTILITY_THREAD_POOL_NAME = NAME + "_utility"; + public static final String TRAINED_MODEL_CIRCUIT_BREAKER_NAME = "model_inference"; + + private static final long DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT = (long)((0.50) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()); + private static final double DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD = 1.0D; // This is for performance testing. It's not exposed to the end user. // Recompile if you want to compare performance with C++ tokenization. public static final boolean CATEGORIZATION_TOKENIZATION_IN_JAVA = true; @@ -436,6 +448,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys private final SetOnce dataFrameAnalyticsAuditor = new SetOnce<>(); private final SetOnce memoryTracker = new SetOnce<>(); private final SetOnce mlUpgradeModeActionFilter = new SetOnce<>(); + private final SetOnce inferenceModelBreaker = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -661,10 +674,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys inferenceAuditor, threadPool, clusterService, - xContentRegistry, trainedModelStatsService, settings, - clusterService.getNodeName()); + clusterService.getNodeName(), + inferenceModelBreaker.get()); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, @@ -1001,4 +1014,23 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys new SystemIndexDescriptor(InferenceIndexConstants.INDEX_PATTERN, "Contains ML model configuration and statistics") )); } + + @Override + public BreakerSettings getCircuitBreaker(Settings settings) { + return BreakerSettings.updateFromSettings( + new BreakerSettings( + TRAINED_MODEL_CIRCUIT_BREAKER_NAME, + DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT, + DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD, + CircuitBreaker.Type.MEMORY, + CircuitBreaker.Durability.TRANSIENT + ), + settings); + } + + @Override + public void setCircuitBreaker(CircuitBreaker circuitBreaker) { + assert circuitBreaker.getName().equals(TRAINED_MODEL_CIRCUIT_BREAKER_NAME); + this.inferenceModelBreaker.set(circuitBreaker); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 7784ac097da..436cdfd54f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -15,6 +15,8 @@ import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; import org.elasticsearch.common.cache.RemovalNotification; @@ -24,7 +26,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.set.Sets; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; @@ -95,15 +96,16 @@ public class ModelLoadingService implements ClusterStateListener { private final InferenceAuditor auditor; private final ByteSizeValue maxCacheSize; private final String localNode; + private final CircuitBreaker trainedModelCircuitBreaker; public ModelLoadingService(TrainedModelProvider trainedModelProvider, InferenceAuditor auditor, ThreadPool threadPool, ClusterService clusterService, - NamedXContentRegistry namedXContentRegistry, TrainedModelStatsService modelStatsService, Settings settings, - String localNode) { + String localNode, + CircuitBreaker trainedModelCircuitBreaker) { this.provider = trainedModelProvider; this.threadPool = threadPool; this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings); @@ -119,6 +121,7 @@ public class ModelLoadingService implements ClusterStateListener { .build(); clusterService.addListener(this); this.localNode = localNode; + this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker"); } /** @@ -149,21 +152,32 @@ public class ModelLoadingService implements ClusterStateListener { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); - provider.getTrainedModelForInference(modelId, ActionListener.wrap( - configAndInferenceDef -> { - TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1(); - InferenceDefinition inferenceDefinition = configAndInferenceDef.v2(); - InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? - inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : - trainedModelConfig.getInferenceConfig(); - modelActionListener.onResponse(new LocalModel( - trainedModelConfig.getModelId(), - localNode, - inferenceDefinition, - trainedModelConfig.getInput(), - trainedModelConfig.getDefaultFieldMap(), - inferenceConfig, - modelStatsService)); + provider.getTrainedModel(modelId, false, ActionListener.wrap( + trainedModelConfig -> { + // Verify we can pull the model into memory without causing OOM + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); + provider.getTrainedModelForInference(modelId, ActionListener.wrap( + inferenceDefinition -> { + InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? + inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : + trainedModelConfig.getInferenceConfig(); + // Remove the bytes as we cannot control how long the caller will keep the model in memory + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + modelActionListener.onResponse(new LocalModel( + trainedModelConfig.getModelId(), + localNode, + inferenceDefinition, + trainedModelConfig.getInput(), + trainedModelConfig.getDefaultFieldMap(), + inferenceConfig, + modelStatsService)); + }, + // Failure getting the definition, remove the initial estimation value + e -> { + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + modelActionListener.onFailure(e); + } + )); }, modelActionListener::onFailure )); @@ -205,29 +219,53 @@ public class ModelLoadingService implements ClusterStateListener { } private void loadModel(String modelId) { - provider.getTrainedModelForInference(modelId, ActionListener.wrap( - configAndInferenceDef -> { - logger.debug(() -> new ParameterizedMessage("[{}] successfully loaded model", modelId)); - handleLoadSuccess(modelId, configAndInferenceDef); + provider.getTrainedModel(modelId, false, ActionListener.wrap( + trainedModelConfig -> { + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); + provider.getTrainedModelForInference(modelId, ActionListener.wrap( + inferenceDefinition -> { + // Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory + // So that the memory this model uses in the circuit breaker is the most accurate estimate. + long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory(); + if (estimateDiff < 0) { + trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff); + } else if (estimateDiff > 0) { // rare case where estimate is now HIGHER + try { + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId); + } catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + handleLoadFailure(modelId, ex); + return; + } + } + handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition); + }, + failure -> { + // We failed to get the definition, remove the initial estimation. + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure); + handleLoadFailure(modelId, failure); + } + )); }, failure -> { - logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure); + logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure); handleLoadFailure(modelId, failure); } )); } private void handleLoadSuccess(String modelId, - Tuple configAndInferenceDef) { + TrainedModelConfig trainedModelConfig, + InferenceDefinition inferenceDefinition) { Queue> listeners; - TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1(); InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? - inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : + inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : trainedModelConfig.getInferenceConfig(); LocalModel loadedModel = new LocalModel( trainedModelConfig.getModelId(), localNode, - configAndInferenceDef.v2(), + inferenceDefinition, trainedModelConfig.getInput(), trainedModelConfig.getDefaultFieldMap(), inferenceConfig, @@ -237,6 +275,7 @@ public class ModelLoadingService implements ClusterStateListener { // If there is no loadingListener that means the loading was canceled and the listener was already notified as such // Consequently, we should not store the retrieved model if (listeners == null) { + trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed()); return; } localModelCache.put(modelId, loadedModel); @@ -263,20 +302,24 @@ public class ModelLoadingService implements ClusterStateListener { } private void cacheEvictionListener(RemovalNotification notification) { - if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { - MessageSupplier msg = () -> new ParameterizedMessage( - "model cache entry evicted." + - "current cache [{}] current max [{}] model size [{}]. " + - "If this is undesired, consider updating setting [{}] or [{}].", - new ByteSizeValue(localModelCache.weight()).getStringRep(), - maxCacheSize.getStringRep(), - new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(), - INFERENCE_MODEL_CACHE_SIZE.getKey(), - INFERENCE_MODEL_CACHE_TTL.getKey()); - auditIfNecessary(notification.getKey(), msg); + try { + if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { + MessageSupplier msg = () -> new ParameterizedMessage( + "model cache entry evicted." + + "current cache [{}] current max [{}] model size [{}]. " + + "If this is undesired, consider updating setting [{}] or [{}].", + new ByteSizeValue(localModelCache.weight()).getStringRep(), + maxCacheSize.getStringRep(), + new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(), + INFERENCE_MODEL_CACHE_SIZE.getKey(), + INFERENCE_MODEL_CACHE_TTL.getKey()); + auditIfNecessary(notification.getKey(), msg); + } + // If the model is no longer referenced, flush the stats to persist as soon as possible + notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false); + } finally { + trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed()); } - // If the model is no longer referenced, flush the stats to persist as soon as possible - notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 284ed4ea67d..5585fe09dc8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -232,19 +232,17 @@ public class TrainedModelProvider { executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener); } - public void getTrainedModelForInference(final String modelId, - final ActionListener> listener) { + public void getTrainedModelForInference(final String modelId, final ActionListener listener) { // TODO Change this when we get more than just langIdent stored if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry); assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork; - listener.onResponse(Tuple.tuple( - config, + listener.onResponse( InferenceDefinition.builder() .setPreProcessors(config.getModelDefinition().getPreProcessors()) .setTrainedModel((LangIdentNeuralNetwork)config.getModelDefinition().getTrainedModel()) - .build())); + .build()); return; } catch (ElasticsearchException|IOException ex) { listener.onFailure(ex); @@ -252,46 +250,52 @@ public class TrainedModelProvider { } } - getTrainedModel(modelId, false, ActionListener.wrap( - config -> { - SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) - .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders - .boolQuery() - .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) - .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), - TrainedModelDefinitionDoc.NAME)))) - .setSize(MAX_NUM_DEFINITION_DOCS) - // First find the latest index - .addSort("_index", SortOrder.DESC) - // Then, sort by doc_num - .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) - .order(SortOrder.ASC) - .unmappedType("long")) - .request(); - executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( - searchResponse -> { - List docs = handleHits(searchResponse.getHits().getHits(), - modelId, - this::parseModelDefinitionDocLenientlyFromSource); - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; - } - InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( - compressedString, - InferenceDefinition::fromXContent, - xContentRegistry); - listener.onResponse(Tuple.tuple(config, inferenceDefinition)); - }, - listener::onFailure - )); - + SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders + .boolQuery() + .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), + TrainedModelDefinitionDoc.NAME)))) + .setSize(MAX_NUM_DEFINITION_DOCS) + // First find the latest index + .addSort("_index", SortOrder.DESC) + // Then, sort by doc_num + .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) + .order(SortOrder.ASC) + .unmappedType("long")) + .request(); + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getHits().length == 0) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return; + } + List docs = handleHits(searchResponse.getHits().getHits(), + modelId, + this::parseModelDefinitionDocLenientlyFromSource); + String compressedString = docs.stream() + .map(TrainedModelDefinitionDoc::getCompressedString) + .collect(Collectors.joining()); + if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { + listener.onFailure(ExceptionsHelper.serverError( + Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + return; + } + InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( + compressedString, + InferenceDefinition::fromXContent, + xContentRegistry); + listener.onResponse(inferenceDefinition); }, - listener::onFailure + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return; + } + listener.onFailure(e); + } )); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index f1c1144d1f3..8d3299c828d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.TransportAction; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.license.LicenseService; @@ -29,18 +30,22 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import static org.elasticsearch.xpack.ml.MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME; + public class LocalStateMachineLearning extends LocalStateCompositeXPackPlugin { public LocalStateMachineLearning(final Settings settings, final Path configPath) throws Exception { super(settings, configPath); LocalStateMachineLearning thisVar = this; + MachineLearning plugin = new MachineLearning(settings, configPath){ - plugins.add(new MachineLearning(settings, configPath){ @Override protected XPackLicenseState getLicenseState() { return thisVar.getLicenseState(); } - }); + }; + plugin.setCircuitBreaker(new NoopCircuitBreaker(TRAINED_MODEL_CIRCUIT_BREAKER_NAME)); + plugins.add(plugin); plugins.add(new Monitoring(settings) { @Override protected SSLService getSslService() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index c525989e7f1..98f38bcd52c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -18,13 +18,13 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -61,6 +62,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -83,6 +85,7 @@ public class ModelLoadingServiceTests extends ESTestCase { private ClusterService clusterService; private InferenceAuditor auditor; private TrainedModelStatsService trainedModelStatsService; + private CircuitBreaker circuitBreaker; @Before public void setUpComponents() { @@ -97,6 +100,7 @@ public class ModelLoadingServiceTests extends ESTestCase { doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class)); doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class)); when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build()); + circuitBreaker = new CustomCircuitBreaker(1000); } @After @@ -116,10 +120,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); @@ -163,10 +167,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(), - "test-node"); + "test-node", + circuitBreaker); // We want to be notified when the models are loaded which happens in a background thread ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds)); @@ -279,10 +283,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); @@ -304,10 +308,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); modelLoadingService.clusterChanged(ingestChangedEvent(model)); PlainActionFuture future = new PlainActionFuture<>(); @@ -332,10 +336,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); @@ -355,10 +359,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); for(int i = 0; i < 3; i++) { PlainActionFuture future = new PlainActionFuture<>(); @@ -370,6 +374,50 @@ public class ModelLoadingServiceTests extends ESTestCase { verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } + public void testCircuitBreakerBreak() throws Exception { + String model1 = "test-circuit-break-model-1"; + String model2 = "test-circuit-break-model-2"; + String model3 = "test-circuit-break-model-3"; + withTrainedModel(model1, 5L); + withTrainedModel(model2, 5L); + withTrainedModel(model3, 12L); + CircuitBreaker circuitBreaker = new CustomCircuitBreaker(11); + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + modelLoadingService.addModelLoadedListener(model3, ActionListener.wrap( + r -> fail("Should not have succeeded to load model as breaker should be reached"), + e -> assertThat(e, instanceOf(CircuitBreakingException.class)) + )); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + + // Should have been loaded from the cluster change event but it is unknown in what order + // the loading occurred or which models are currently in the cache due to evictions. + // Verify that we have at least loaded all three + assertBusy(() -> { + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any()); + }); + assertBusy(() -> { + assertThat(circuitBreaker.getUsed(), equalTo(10L)); + assertThat(circuitBreaker.getTrippedCount(), equalTo(1L)); + }); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1)); + + assertBusy(() -> { + assertThat(circuitBreaker.getUsed(), equalTo(5L)); + }); + } + @SuppressWarnings("unchecked") private void withTrainedModel(String modelId, long size) { InferenceDefinition definition = mock(InferenceDefinition.class); @@ -378,15 +426,48 @@ public class ModelLoadingServiceTests extends ESTestCase { when(trainedModelConfig.getModelId()).thenReturn(modelId); when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS); when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); + when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(size); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - listener.onResponse(Tuple.tuple(trainedModelConfig, definition)); + listener.onResponse(definition); return null; }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); } + @SuppressWarnings("unchecked") private void withMissingModel(String modelId) { + if (randomBoolean()) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + } else { + TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); + when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return null; + }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); + } doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -438,6 +519,79 @@ public class ModelLoadingServiceTests extends ESTestCase { } } + private static class CustomCircuitBreaker implements CircuitBreaker { + + private final long maxBytes; + private long currentBytes = 0; + private long trippedCount = 0; + + CustomCircuitBreaker(long maxBytes) { + this.maxBytes = maxBytes; + } + + @Override + public void circuitBreak(String fieldName, long bytesNeeded) { + throw new CircuitBreakingException(fieldName, Durability.TRANSIENT); + } + + @Override + public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + synchronized (this) { + if (bytes + currentBytes >= maxBytes) { + trippedCount++; + circuitBreak(label, bytes); + } + currentBytes += bytes; + return currentBytes; + } + } + + @Override + public long addWithoutBreaking(long bytes) { + synchronized (this) { + currentBytes += bytes; + return currentBytes; + } + } + + @Override + public long getUsed() { + return currentBytes; + } + + @Override + public long getLimit() { + return maxBytes; + } + + @Override + public double getOverhead() { + return 1.0; + } + + @Override + public long getTrippedCount() { + synchronized (this) { + return trippedCount; + } + } + + @Override + public String getName() { + return MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME; + } + + @Override + public Durability getDurability() { + return Durability.TRANSIENT; + } + + @Override + public void setLimitAndOverhead(long limit, double overhead) { + throw new UnsupportedOperationException("boom"); + } + } + private static class ModelLoadedTracker { private final Set expectedModelIds; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java index 45bb5a5405e..aeee77bd458 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java @@ -60,6 +60,7 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts; import org.elasticsearch.xpack.ilm.IndexLifecycle; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.monitoring.MonitoringService; import org.junit.After; import org.junit.Before; @@ -98,6 +99,8 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase { settings.put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial"); settings.put(XPackSettings.WATCHER_ENABLED.getKey(), false); settings.put(XPackSettings.GRAPH_ENABLED.getKey(), false); + settings.put(MonitoringService.ENABLED.getKey(), false); + settings.put(MonitoringService.ELASTICSEARCH_COLLECTION_ENABLED.getKey(), false); settings.put(LifecycleSettings.LIFECYCLE_HISTORY_INDEX_ENABLED_SETTING.getKey(), false); return settings.build(); }