diff --git a/docs/reference/settings/ml-settings.asciidoc b/docs/reference/settings/ml-settings.asciidoc index 7cdeba856f2..73bfcd36463 100644 --- a/docs/reference/settings/ml-settings.asciidoc +++ b/docs/reference/settings/ml-settings.asciidoc @@ -56,6 +56,7 @@ The maximum inference cache size allowed. The inference cache exists in the JVM heap on each ingest node. The cache affords faster processing times for the `inference` processor. The value can be a static byte sized value (i.e. "2gb") or a percentage of total allocated heap. The default is "40%". +See also <<model-inference-circuit-breaker>>. `xpack.ml.inference_model.time_to_live`:: The time to live (TTL) for models in the inference model cache. The TTL is @@ -137,3 +138,24 @@ to the {es} JVM. When such processes are started they must connect to the {es} JVM. If such a process does not connect within the time period specified by this setting then the process is assumed to have failed. Defaults to `10s`. The minimum value for this setting is `5s`. + +[[model-inference-circuit-breaker]] +==== {ml-cap} circuit breaker settings + +`breaker.model_inference.limit` (<<cluster-update-settings,Dynamic>>):: +Limit for model inference breaker, defaults to 50% of JVM heap. +If the parent circuit breaker is less than 50% of JVM heap, it is bound +to that limit instead. +See <<circuit-breaker>>. + +`breaker.model_inference.overhead` (<<cluster-update-settings,Dynamic>>):: +A constant that all accounting estimations are multiplied with to determine +a final estimation. Defaults to 1. +See <<circuit-breaker>>. + +`breaker.model_inference.type`:: +The underlying type of the circuit breaker. There are two valid options: +`noop`, meaning the circuit breaker does nothing to prevent too much memory usage, +`memory`, meaning the circuit breaker tracks the memory used by inference models and +could potentially break and prevent OutOfMemory errors. +The default is `memory`. 
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index bd037714991..1372db30e5b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Module; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.IndexScopedSettings; @@ -40,12 +41,15 @@ import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider; +import org.elasticsearch.indices.breaker.BreakerSettings; import org.elasticsearch.ingest.Processor; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.monitor.os.OsProbe; import org.elasticsearch.monitor.os.OsStats; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.plugins.CircuitBreakerPlugin; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; @@ -323,7 +327,11 @@ import java.util.function.UnaryOperator; import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; -public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin { 
+public class MachineLearning extends Plugin implements SystemIndexPlugin, + AnalysisPlugin, + CircuitBreakerPlugin, + IngestPlugin, + PersistentTaskPlugin { public static final String NAME = "ml"; public static final String BASE_PATH = "/_ml/"; public static final String PRE_V7_BASE_PATH = "/_xpack/ml/"; @@ -331,6 +339,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys public static final String JOB_COMMS_THREAD_POOL_NAME = NAME + "_job_comms"; public static final String UTILITY_THREAD_POOL_NAME = NAME + "_utility"; + public static final String TRAINED_MODEL_CIRCUIT_BREAKER_NAME = "model_inference"; + + private static final long DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT = (long)((0.50) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()); + private static final double DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD = 1.0D; // This is for performance testing. It's not exposed to the end user. // Recompile if you want to compare performance with C++ tokenization. public static final boolean CATEGORIZATION_TOKENIZATION_IN_JAVA = true; @@ -436,6 +448,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys private final SetOnce dataFrameAnalyticsAuditor = new SetOnce<>(); private final SetOnce memoryTracker = new SetOnce<>(); private final SetOnce mlUpgradeModeActionFilter = new SetOnce<>(); + private final SetOnce inferenceModelBreaker = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -661,10 +674,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys inferenceAuditor, threadPool, clusterService, - xContentRegistry, trainedModelStatsService, settings, - clusterService.getNodeName()); + clusterService.getNodeName(), + inferenceModelBreaker.get()); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, @@ -1001,4 +1014,23 @@ 
public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys new SystemIndexDescriptor(InferenceIndexConstants.INDEX_PATTERN, "Contains ML model configuration and statistics") )); } + + @Override + public BreakerSettings getCircuitBreaker(Settings settings) { + return BreakerSettings.updateFromSettings( + new BreakerSettings( + TRAINED_MODEL_CIRCUIT_BREAKER_NAME, + DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT, + DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD, + CircuitBreaker.Type.MEMORY, + CircuitBreaker.Durability.TRANSIENT + ), + settings); + } + + @Override + public void setCircuitBreaker(CircuitBreaker circuitBreaker) { + assert circuitBreaker.getName().equals(TRAINED_MODEL_CIRCUIT_BREAKER_NAME); + this.inferenceModelBreaker.set(circuitBreaker); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 7784ac097da..436cdfd54f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -15,6 +15,8 @@ import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; import org.elasticsearch.common.cache.RemovalNotification; @@ -24,7 +26,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.set.Sets; 
-import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; @@ -95,15 +96,16 @@ public class ModelLoadingService implements ClusterStateListener { private final InferenceAuditor auditor; private final ByteSizeValue maxCacheSize; private final String localNode; + private final CircuitBreaker trainedModelCircuitBreaker; public ModelLoadingService(TrainedModelProvider trainedModelProvider, InferenceAuditor auditor, ThreadPool threadPool, ClusterService clusterService, - NamedXContentRegistry namedXContentRegistry, TrainedModelStatsService modelStatsService, Settings settings, - String localNode) { + String localNode, + CircuitBreaker trainedModelCircuitBreaker) { this.provider = trainedModelProvider; this.threadPool = threadPool; this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings); @@ -119,6 +121,7 @@ public class ModelLoadingService implements ClusterStateListener { .build(); clusterService.addListener(this); this.localNode = localNode; + this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker"); } /** @@ -149,21 +152,32 @@ public class ModelLoadingService implements ClusterStateListener { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); - provider.getTrainedModelForInference(modelId, ActionListener.wrap( - configAndInferenceDef -> { - TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1(); - InferenceDefinition inferenceDefinition = configAndInferenceDef.v2(); - InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? 
- inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : - trainedModelConfig.getInferenceConfig(); - modelActionListener.onResponse(new LocalModel( - trainedModelConfig.getModelId(), - localNode, - inferenceDefinition, - trainedModelConfig.getInput(), - trainedModelConfig.getDefaultFieldMap(), - inferenceConfig, - modelStatsService)); + provider.getTrainedModel(modelId, false, ActionListener.wrap( + trainedModelConfig -> { + // Verify we can pull the model into memory without causing OOM + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); + provider.getTrainedModelForInference(modelId, ActionListener.wrap( + inferenceDefinition -> { + InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? + inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : + trainedModelConfig.getInferenceConfig(); + // Remove the bytes as we cannot control how long the caller will keep the model in memory + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + modelActionListener.onResponse(new LocalModel( + trainedModelConfig.getModelId(), + localNode, + inferenceDefinition, + trainedModelConfig.getInput(), + trainedModelConfig.getDefaultFieldMap(), + inferenceConfig, + modelStatsService)); + }, + // Failure getting the definition, remove the initial estimation value + e -> { + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + modelActionListener.onFailure(e); + } + )); }, modelActionListener::onFailure )); @@ -205,29 +219,53 @@ public class ModelLoadingService implements ClusterStateListener { } private void loadModel(String modelId) { - provider.getTrainedModelForInference(modelId, ActionListener.wrap( - configAndInferenceDef -> { - logger.debug(() -> new ParameterizedMessage("[{}] successfully loaded model", modelId)); - handleLoadSuccess(modelId, configAndInferenceDef); + 
provider.getTrainedModel(modelId, false, ActionListener.wrap( + trainedModelConfig -> { + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); + provider.getTrainedModelForInference(modelId, ActionListener.wrap( + inferenceDefinition -> { + // Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory + // So that the memory this model uses in the circuit breaker is the most accurate estimate. + long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory(); + if (estimateDiff < 0) { + trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff); + } else if (estimateDiff > 0) { // rare case where estimate is now HIGHER + try { + trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId); + } catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well + trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + handleLoadFailure(modelId, ex); + return; + } + } + handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition); + }, + failure -> { + // We failed to get the definition, remove the initial estimation. 
+ trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory()); + logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure); + handleLoadFailure(modelId, failure); + } + )); }, failure -> { - logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure); + logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure); handleLoadFailure(modelId, failure); } )); } private void handleLoadSuccess(String modelId, - Tuple configAndInferenceDef) { + TrainedModelConfig trainedModelConfig, + InferenceDefinition inferenceDefinition) { Queue> listeners; - TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1(); InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? - inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) : + inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) : trainedModelConfig.getInferenceConfig(); LocalModel loadedModel = new LocalModel( trainedModelConfig.getModelId(), localNode, - configAndInferenceDef.v2(), + inferenceDefinition, trainedModelConfig.getInput(), trainedModelConfig.getDefaultFieldMap(), inferenceConfig, @@ -237,6 +275,7 @@ public class ModelLoadingService implements ClusterStateListener { // If there is no loadingListener that means the loading was canceled and the listener was already notified as such // Consequently, we should not store the retrieved model if (listeners == null) { + trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed()); return; } localModelCache.put(modelId, loadedModel); @@ -263,20 +302,24 @@ public class ModelLoadingService implements ClusterStateListener { } private void cacheEvictionListener(RemovalNotification notification) { - if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { - MessageSupplier msg = () -> new 
ParameterizedMessage( - "model cache entry evicted." + - "current cache [{}] current max [{}] model size [{}]. " + - "If this is undesired, consider updating setting [{}] or [{}].", - new ByteSizeValue(localModelCache.weight()).getStringRep(), - maxCacheSize.getStringRep(), - new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(), - INFERENCE_MODEL_CACHE_SIZE.getKey(), - INFERENCE_MODEL_CACHE_TTL.getKey()); - auditIfNecessary(notification.getKey(), msg); + try { + if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { + MessageSupplier msg = () -> new ParameterizedMessage( + "model cache entry evicted." + + "current cache [{}] current max [{}] model size [{}]. " + + "If this is undesired, consider updating setting [{}] or [{}].", + new ByteSizeValue(localModelCache.weight()).getStringRep(), + maxCacheSize.getStringRep(), + new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(), + INFERENCE_MODEL_CACHE_SIZE.getKey(), + INFERENCE_MODEL_CACHE_TTL.getKey()); + auditIfNecessary(notification.getKey(), msg); + } + // If the model is no longer referenced, flush the stats to persist as soon as possible + notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false); + } finally { + trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed()); } - // If the model is no longer referenced, flush the stats to persist as soon as possible - notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 284ed4ea67d..5585fe09dc8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -232,19 +232,17 @@ public class TrainedModelProvider { executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener); } - public void getTrainedModelForInference(final String modelId, - final ActionListener> listener) { + public void getTrainedModelForInference(final String modelId, final ActionListener listener) { // TODO Change this when we get more than just langIdent stored if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry); assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork; - listener.onResponse(Tuple.tuple( - config, + listener.onResponse( InferenceDefinition.builder() .setPreProcessors(config.getModelDefinition().getPreProcessors()) .setTrainedModel((LangIdentNeuralNetwork)config.getModelDefinition().getTrainedModel()) - .build())); + .build()); return; } catch (ElasticsearchException|IOException ex) { listener.onFailure(ex); @@ -252,46 +250,52 @@ public class TrainedModelProvider { } } - getTrainedModel(modelId, false, ActionListener.wrap( - config -> { - SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) - .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders - .boolQuery() - .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) - .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), - TrainedModelDefinitionDoc.NAME)))) - .setSize(MAX_NUM_DEFINITION_DOCS) - // First find the latest index - .addSort("_index", SortOrder.DESC) - // Then, sort by doc_num - .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) - .order(SortOrder.ASC) - .unmappedType("long")) - .request(); - executeAsyncWithOrigin(client, ML_ORIGIN, 
SearchAction.INSTANCE, searchRequest, ActionListener.wrap( - searchResponse -> { - List docs = handleHits(searchResponse.getHits().getHits(), - modelId, - this::parseModelDefinitionDocLenientlyFromSource); - String compressedString = docs.stream() - .map(TrainedModelDefinitionDoc::getCompressedString) - .collect(Collectors.joining()); - if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { - listener.onFailure(ExceptionsHelper.serverError( - Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); - return; - } - InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( - compressedString, - InferenceDefinition::fromXContent, - xContentRegistry); - listener.onResponse(Tuple.tuple(config, inferenceDefinition)); - }, - listener::onFailure - )); - + SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders + .boolQuery() + .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), + TrainedModelDefinitionDoc.NAME)))) + .setSize(MAX_NUM_DEFINITION_DOCS) + // First find the latest index + .addSort("_index", SortOrder.DESC) + // Then, sort by doc_num + .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()) + .order(SortOrder.ASC) + .unmappedType("long")) + .request(); + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getHits().length == 0) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return; + } + List docs = handleHits(searchResponse.getHits().getHits(), + modelId, + this::parseModelDefinitionDocLenientlyFromSource); + String compressedString = docs.stream() + 
.map(TrainedModelDefinitionDoc::getCompressedString) + .collect(Collectors.joining()); + if (compressedString.length() != docs.get(0).getTotalDefinitionLength()) { + listener.onFailure(ExceptionsHelper.serverError( + Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId))); + return; + } + InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate( + compressedString, + InferenceDefinition::fromXContent, + xContentRegistry); + listener.onResponse(inferenceDefinition); }, - listener::onFailure + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return; + } + listener.onFailure(e); + } )); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index f1c1144d1f3..8d3299c828d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.TransportAction; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.license.LicenseService; @@ -29,18 +30,22 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import static org.elasticsearch.xpack.ml.MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME; + public class LocalStateMachineLearning extends LocalStateCompositeXPackPlugin { public LocalStateMachineLearning(final Settings settings, final Path configPath) 
throws Exception { super(settings, configPath); LocalStateMachineLearning thisVar = this; + MachineLearning plugin = new MachineLearning(settings, configPath){ - plugins.add(new MachineLearning(settings, configPath){ @Override protected XPackLicenseState getLicenseState() { return thisVar.getLicenseState(); } - }); + }; + plugin.setCircuitBreaker(new NoopCircuitBreaker(TRAINED_MODEL_CIRCUIT_BREAKER_NAME)); + plugins.add(plugin); plugins.add(new Monitoring(settings) { @Override protected SSLService getSslService() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index c525989e7f1..98f38bcd52c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -18,13 +18,13 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; @@ -40,6 +40,7 @@ import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -61,6 +62,7 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -83,6 +85,7 @@ public class ModelLoadingServiceTests extends ESTestCase { private ClusterService clusterService; private InferenceAuditor auditor; private TrainedModelStatsService trainedModelStatsService; + private CircuitBreaker circuitBreaker; @Before public void setUpComponents() { @@ -97,6 +100,7 @@ public class ModelLoadingServiceTests extends ESTestCase { doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class)); doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class)); when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build()); + circuitBreaker = new CustomCircuitBreaker(1000); } @After @@ -116,10 +120,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); @@ -163,10 +167,10 @@ public 
class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(), - "test-node"); + "test-node", + circuitBreaker); // We want to be notified when the models are loaded which happens in a background thread ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds)); @@ -279,10 +283,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); modelLoadingService.clusterChanged(ingestChangedEvent(false, model1)); @@ -304,10 +308,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); modelLoadingService.clusterChanged(ingestChangedEvent(model)); PlainActionFuture future = new PlainActionFuture<>(); @@ -332,10 +336,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); PlainActionFuture future = new PlainActionFuture<>(); modelLoadingService.getModel(model, future); @@ -355,10 +359,10 @@ public class ModelLoadingServiceTests extends ESTestCase { auditor, threadPool, clusterService, - NamedXContentRegistry.EMPTY, trainedModelStatsService, Settings.EMPTY, - "test-node"); + "test-node", + circuitBreaker); for(int i = 0; i < 3; i++) { PlainActionFuture future = new PlainActionFuture<>(); @@ -370,6 +374,50 @@ public class ModelLoadingServiceTests extends ESTestCase { verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } + 
public void testCircuitBreakerBreak() throws Exception { + String model1 = "test-circuit-break-model-1"; + String model2 = "test-circuit-break-model-2"; + String model3 = "test-circuit-break-model-3"; + withTrainedModel(model1, 5L); + withTrainedModel(model2, 5L); + withTrainedModel(model3, 12L); + CircuitBreaker circuitBreaker = new CustomCircuitBreaker(11); + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + modelLoadingService.addModelLoadedListener(model3, ActionListener.wrap( + r -> fail("Should not have succeeded to load model as breaker should be reached"), + e -> assertThat(e, instanceOf(CircuitBreakingException.class)) + )); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); + + // Should have been loaded from the cluster change event but it is unknown in what order + // the loading occurred or which models are currently in the cache due to evictions. 
+ // Verify that we have at least loaded all three + assertBusy(() -> { + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any()); + }); + assertBusy(() -> { + assertThat(circuitBreaker.getUsed(), equalTo(10L)); + assertThat(circuitBreaker.getTrippedCount(), equalTo(1L)); + }); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1)); + + assertBusy(() -> { + assertThat(circuitBreaker.getUsed(), equalTo(5L)); + }); + } + @SuppressWarnings("unchecked") private void withTrainedModel(String modelId, long size) { InferenceDefinition definition = mock(InferenceDefinition.class); @@ -378,15 +426,48 @@ public class ModelLoadingServiceTests extends ESTestCase { when(trainedModelConfig.getModelId()).thenReturn(modelId); when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS); when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz"))); + when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(size); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - listener.onResponse(Tuple.tuple(trainedModelConfig, definition)); + listener.onResponse(definition); return null; }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); } + @SuppressWarnings("unchecked") private void withMissingModel(String modelId) { + if (randomBoolean()) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + 
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + } else { + TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class); + when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return null; + }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); + } doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -438,6 +519,79 @@ public class ModelLoadingServiceTests extends ESTestCase { } } + private static class CustomCircuitBreaker implements CircuitBreaker { + + private final long maxBytes; + private long currentBytes = 0; + private long trippedCount = 0; + + CustomCircuitBreaker(long maxBytes) { + this.maxBytes = maxBytes; + } + + @Override + public void circuitBreak(String fieldName, long bytesNeeded) { + throw new CircuitBreakingException(fieldName, Durability.TRANSIENT); + } + + @Override + public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + synchronized (this) { + if (bytes + currentBytes >= maxBytes) { + trippedCount++; + circuitBreak(label, bytes); + } + currentBytes += bytes; + return 
currentBytes; + } + } + + @Override + public long addWithoutBreaking(long bytes) { + synchronized (this) { + currentBytes += bytes; + return currentBytes; + } + } + + @Override + public long getUsed() { + return currentBytes; + } + + @Override + public long getLimit() { + return maxBytes; + } + + @Override + public double getOverhead() { + return 1.0; + } + + @Override + public long getTrippedCount() { + synchronized (this) { + return trippedCount; + } + } + + @Override + public String getName() { + return MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME; + } + + @Override + public Durability getDurability() { + return Durability.TRANSIENT; + } + + @Override + public void setLimitAndOverhead(long limit, double overhead) { + throw new UnsupportedOperationException("boom"); + } + } + private static class ModelLoadedTracker { private final Set expectedModelIds; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java index 45bb5a5405e..aeee77bd458 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java @@ -60,6 +60,7 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts; import org.elasticsearch.xpack.ilm.IndexLifecycle; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.monitoring.MonitoringService; import org.junit.After; import org.junit.Before; @@ -98,6 +99,8 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase { settings.put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial"); settings.put(XPackSettings.WATCHER_ENABLED.getKey(), false); settings.put(XPackSettings.GRAPH_ENABLED.getKey(), false); + 
settings.put(MonitoringService.ENABLED.getKey(), false); + settings.put(MonitoringService.ELASTICSEARCH_COLLECTION_ENABLED.getKey(), false); settings.put(LifecycleSettings.LIFECYCLE_HISTORY_INDEX_ENABLED_SETTING.getKey(), false); return settings.build(); }