[ML] add new circuit breaker for inference model caching (#57731) (#57830)

This adds new plugin level circuit breaker for the ML plugin.

`model_inference` is the circuit breaker qualified name.

Right now it simply adds to the breaker when the model is loaded (possibly tripping the breaker) and removes from the breaker when the model is unloaded.
This commit is contained in:
Benjamin Trent 2020-06-08 16:02:48 -04:00 committed by GitHub
parent 6e8cf0973f
commit d5522c2747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 367 additions and 104 deletions

View File

@ -56,6 +56,7 @@ The maximum inference cache size allowed. The inference cache exists in the JVM
heap on each ingest node. The cache affords faster processing times for the
`inference` processor. The value can be a static byte sized value (i.e. "2gb")
or a percentage of total allocated heap. The default is "40%".
See also <<model-inference-circuit-breaker>>.
`xpack.ml.inference_model.time_to_live`::
The time to live (TTL) for models in the inference model cache. The TTL is
@ -137,3 +138,24 @@ to the {es} JVM. When such processes are started they must connect to the {es}
JVM. If such a process does not connect within the time period specified by this
setting then the process is assumed to have failed. Defaults to `10s`. The minimum
value for this setting is `5s`.
[[model-inference-circuit-breaker]]
==== {ml-cap} circuit breaker settings
`breaker.model_inference.limit` (<<cluster-update-settings,Dynamic>>)::
Limit for the model inference breaker, which defaults to 50% of the JVM heap.
If the parent circuit breaker limit is less than 50% of the JVM heap, this breaker
is bound to that limit instead.
See <<circuit-breaker>>.
`breaker.model_inference.overhead` (<<cluster-update-settings,Dynamic>>)::
A constant that all accounting estimations are multiplied by to determine
a final estimation. Defaults to `1`.
See <<circuit-breaker>>.
`breaker.model_inference.type`::
The underlying type of the circuit breaker. There are two valid options: `noop`,
meaning the circuit breaker does nothing to prevent too much memory usage, and
`memory`, meaning the circuit breaker tracks the memory used by inference models
and can break to prevent `OutOfMemoryError`s.
The default is `memory`.

View File

@ -23,6 +23,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Module;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
@ -40,12 +41,15 @@ import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
import org.elasticsearch.indices.breaker.BreakerSettings;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.monitor.os.OsProbe;
import org.elasticsearch.monitor.os.OsStats;
import org.elasticsearch.persistent.PersistentTasksExecutor;
import org.elasticsearch.plugins.AnalysisPlugin;
import org.elasticsearch.plugins.CircuitBreakerPlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
@ -323,7 +327,11 @@ import java.util.function.UnaryOperator;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
public class MachineLearning extends Plugin implements SystemIndexPlugin,
AnalysisPlugin,
CircuitBreakerPlugin,
IngestPlugin,
PersistentTaskPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
@ -331,6 +339,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
public static final String JOB_COMMS_THREAD_POOL_NAME = NAME + "_job_comms";
public static final String UTILITY_THREAD_POOL_NAME = NAME + "_utility";
public static final String TRAINED_MODEL_CIRCUIT_BREAKER_NAME = "model_inference";
private static final long DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT = (long)((0.50) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes());
private static final double DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD = 1.0D;
// This is for performance testing. It's not exposed to the end user.
// Recompile if you want to compare performance with C++ tokenization.
public static final boolean CATEGORIZATION_TOKENIZATION_IN_JAVA = true;
@ -436,6 +448,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
private final SetOnce<ActionFilter> mlUpgradeModeActionFilter = new SetOnce<>();
private final SetOnce<CircuitBreaker> inferenceModelBreaker = new SetOnce<>();
public MachineLearning(Settings settings, Path configPath) {
this.settings = settings;
@ -661,10 +674,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
inferenceAuditor,
threadPool,
clusterService,
xContentRegistry,
trainedModelStatsService,
settings,
clusterService.getNodeName());
clusterService.getNodeName(),
inferenceModelBreaker.get());
// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
@ -1001,4 +1014,23 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
new SystemIndexDescriptor(InferenceIndexConstants.INDEX_PATTERN, "Contains ML model configuration and statistics")
));
}
/**
 * Declares the {@code model_inference} circuit breaker registered by this plugin.
 * The hard-coded defaults (50% of heap limit, 1.0 overhead, MEMORY type, TRANSIENT
 * durability) may be overridden by node-level {@code breaker.model_inference.*} settings.
 *
 * @param settings the node settings used to override the default breaker configuration
 * @return the effective breaker settings after applying any overrides
 */
@Override
public BreakerSettings getCircuitBreaker(Settings settings) {
    BreakerSettings defaultSettings = new BreakerSettings(
        TRAINED_MODEL_CIRCUIT_BREAKER_NAME,
        DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT,
        DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD,
        CircuitBreaker.Type.MEMORY,
        CircuitBreaker.Durability.TRANSIENT);
    // Apply any user-provided breaker.model_inference.* settings over the defaults.
    return BreakerSettings.updateFromSettings(defaultSettings, settings);
}
/**
 * Callback from the breaker service supplying the breaker instance built from
 * {@link #getCircuitBreaker(Settings)}. Stored in a {@code SetOnce} so it can later
 * be handed to the model loading components (see plugin component creation).
 */
@Override
public void setCircuitBreaker(CircuitBreaker circuitBreaker) {
    // This plugin registers exactly one breaker; anything else is a wiring bug.
    assert circuitBreaker.getName().equals(TRAINED_MODEL_CIRCUIT_BREAKER_NAME);
    this.inferenceModelBreaker.set(circuitBreaker);
}
}

View File

@ -15,6 +15,8 @@ import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.RemovalNotification;
@ -24,7 +26,6 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@ -95,15 +96,16 @@ public class ModelLoadingService implements ClusterStateListener {
private final InferenceAuditor auditor;
private final ByteSizeValue maxCacheSize;
private final String localNode;
private final CircuitBreaker trainedModelCircuitBreaker;
public ModelLoadingService(TrainedModelProvider trainedModelProvider,
InferenceAuditor auditor,
ThreadPool threadPool,
ClusterService clusterService,
NamedXContentRegistry namedXContentRegistry,
TrainedModelStatsService modelStatsService,
Settings settings,
String localNode) {
String localNode,
CircuitBreaker trainedModelCircuitBreaker) {
this.provider = trainedModelProvider;
this.threadPool = threadPool;
this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
@ -119,6 +121,7 @@ public class ModelLoadingService implements ClusterStateListener {
.build();
clusterService.addListener(this);
this.localNode = localNode;
this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
}
/**
@ -149,13 +152,17 @@ public class ModelLoadingService implements ClusterStateListener {
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
// by a simulated pipeline
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
provider.getTrainedModel(modelId, false, ActionListener.wrap(
trainedModelConfig -> {
// Verify we can pull the model into memory without causing OOM
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
configAndInferenceDef -> {
TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1();
InferenceDefinition inferenceDefinition = configAndInferenceDef.v2();
inferenceDefinition -> {
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
trainedModelConfig.getInferenceConfig();
// Remove the bytes as we cannot control how long the caller will keep the model in memory
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
modelActionListener.onResponse(new LocalModel(
trainedModelConfig.getModelId(),
localNode,
@ -165,6 +172,13 @@ public class ModelLoadingService implements ClusterStateListener {
inferenceConfig,
modelStatsService));
},
// Failure getting the definition, remove the initial estimation value
e -> {
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
modelActionListener.onFailure(e);
}
));
},
modelActionListener::onFailure
));
} else {
@ -205,29 +219,53 @@ public class ModelLoadingService implements ClusterStateListener {
}
private void loadModel(String modelId) {
provider.getTrainedModel(modelId, false, ActionListener.wrap(
trainedModelConfig -> {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
configAndInferenceDef -> {
logger.debug(() -> new ParameterizedMessage("[{}] successfully loaded model", modelId));
handleLoadSuccess(modelId, configAndInferenceDef);
inferenceDefinition -> {
// Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory
// So that the memory this model uses in the circuit breaker is the most accurate estimate.
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
if (estimateDiff < 0) {
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
try {
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
handleLoadFailure(modelId, ex);
return;
}
}
handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition);
},
failure -> {
logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure);
// We failed to get the definition, remove the initial estimation.
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure);
handleLoadFailure(modelId, failure);
}
));
},
failure -> {
logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure);
handleLoadFailure(modelId, failure);
}
));
}
private void handleLoadSuccess(String modelId,
Tuple<TrainedModelConfig, InferenceDefinition> configAndInferenceDef) {
TrainedModelConfig trainedModelConfig,
InferenceDefinition inferenceDefinition) {
Queue<ActionListener<Model>> listeners;
TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1();
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) :
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
trainedModelConfig.getInferenceConfig();
LocalModel loadedModel = new LocalModel(
trainedModelConfig.getModelId(),
localNode,
configAndInferenceDef.v2(),
inferenceDefinition,
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
@ -237,6 +275,7 @@ public class ModelLoadingService implements ClusterStateListener {
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
// Consequently, we should not store the retrieved model
if (listeners == null) {
trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed());
return;
}
localModelCache.put(modelId, loadedModel);
@ -263,6 +302,7 @@ public class ModelLoadingService implements ClusterStateListener {
}
private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
try {
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
MessageSupplier msg = () -> new ParameterizedMessage(
"model cache entry evicted." +
@ -277,6 +317,9 @@ public class ModelLoadingService implements ClusterStateListener {
}
// If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false);
} finally {
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed());
}
}
@Override

View File

@ -232,19 +232,17 @@ public class TrainedModelProvider {
executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
}
public void getTrainedModelForInference(final String modelId,
final ActionListener<Tuple<TrainedModelConfig, InferenceDefinition>> listener) {
public void getTrainedModelForInference(final String modelId, final ActionListener<InferenceDefinition> listener) {
// TODO Change this when we get more than just langIdent stored
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
try {
TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
listener.onResponse(Tuple.tuple(
config,
listener.onResponse(
InferenceDefinition.builder()
.setPreProcessors(config.getModelDefinition().getPreProcessors())
.setTrainedModel((LangIdentNeuralNetwork)config.getModelDefinition().getTrainedModel())
.build()));
.build());
return;
} catch (ElasticsearchException|IOException ex) {
listener.onFailure(ex);
@ -252,8 +250,6 @@ public class TrainedModelProvider {
}
}
getTrainedModel(modelId, false, ActionListener.wrap(
config -> {
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
.boolQuery()
@ -270,6 +266,11 @@ public class TrainedModelProvider {
.request();
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
searchResponse -> {
if (searchResponse.getHits().getHits().length == 0) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
}
List<TrainedModelDefinitionDoc> docs = handleHits(searchResponse.getHits().getHits(),
modelId,
this::parseModelDefinitionDocLenientlyFromSource);
@ -285,13 +286,16 @@ public class TrainedModelProvider {
compressedString,
InferenceDefinition::fromXContent,
xContentRegistry);
listener.onResponse(Tuple.tuple(config, inferenceDefinition));
listener.onResponse(inferenceDefinition);
},
listener::onFailure
));
},
listener::onFailure
e -> {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
}
listener.onFailure(e);
}
));
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.TransportAction;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.LicenseService;
@ -29,18 +30,22 @@ import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import static org.elasticsearch.xpack.ml.MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME;
public class LocalStateMachineLearning extends LocalStateCompositeXPackPlugin {
public LocalStateMachineLearning(final Settings settings, final Path configPath) throws Exception {
super(settings, configPath);
LocalStateMachineLearning thisVar = this;
MachineLearning plugin = new MachineLearning(settings, configPath){
plugins.add(new MachineLearning(settings, configPath){
@Override
protected XPackLicenseState getLicenseState() {
return thisVar.getLicenseState();
}
});
};
plugin.setCircuitBreaker(new NoopCircuitBreaker(TRAINED_MODEL_CIRCUIT_BREAKER_NAME));
plugins.add(plugin);
plugins.add(new Monitoring(settings) {
@Override
protected SSLService getSslService() {

View File

@ -18,13 +18,13 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -61,6 +62,7 @@ import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
@ -83,6 +85,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
private ClusterService clusterService;
private InferenceAuditor auditor;
private TrainedModelStatsService trainedModelStatsService;
private CircuitBreaker circuitBreaker;
@Before
public void setUpComponents() {
@ -97,6 +100,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class));
doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class));
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build());
circuitBreaker = new CustomCircuitBreaker(1000);
}
@After
@ -116,10 +120,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
auditor,
threadPool,
clusterService,
NamedXContentRegistry.EMPTY,
trainedModelStatsService,
Settings.EMPTY,
"test-node");
"test-node",
circuitBreaker);
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
@ -163,10 +167,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
auditor,
threadPool,
clusterService,
NamedXContentRegistry.EMPTY,
trainedModelStatsService,
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
"test-node");
"test-node",
circuitBreaker);
// We want to be notified when the models are loaded which happens in a background thread
ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds));
@ -279,10 +283,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
auditor,
threadPool,
clusterService,
NamedXContentRegistry.EMPTY,
trainedModelStatsService,
Settings.EMPTY,
"test-node");
"test-node",
circuitBreaker);
modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));
@ -304,10 +308,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
auditor,
threadPool,
clusterService,
NamedXContentRegistry.EMPTY,
trainedModelStatsService,
Settings.EMPTY,
"test-node");
"test-node",
circuitBreaker);
modelLoadingService.clusterChanged(ingestChangedEvent(model));
PlainActionFuture<Model> future = new PlainActionFuture<>();
@ -332,10 +336,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
auditor,
threadPool,
clusterService,
NamedXContentRegistry.EMPTY,
trainedModelStatsService,
Settings.EMPTY,
"test-node");
"test-node",
circuitBreaker);
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
@ -355,10 +359,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
auditor,
threadPool,
clusterService,
NamedXContentRegistry.EMPTY,
trainedModelStatsService,
Settings.EMPTY,
"test-node");
"test-node",
circuitBreaker);
for(int i = 0; i < 3; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
@ -370,6 +374,50 @@ public class ModelLoadingServiceTests extends ESTestCase {
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
/**
 * Verifies breaker accounting during model loading: three models (5 + 5 + 12 bytes)
 * are loaded against an 11-byte breaker. The 12-byte model always trips the breaker
 * (regardless of load order), the two 5-byte models load successfully, and evicting
 * a model from the referenced set returns its bytes to the breaker.
 */
public void testCircuitBreakerBreak() throws Exception {
String model1 = "test-circuit-break-model-1";
String model2 = "test-circuit-break-model-2";
String model3 = "test-circuit-break-model-3";
withTrainedModel(model1, 5L);
withTrainedModel(model2, 5L);
withTrainedModel(model3, 12L);
// Limit of 11 bytes: both 5-byte models fit (10 total), the 12-byte model cannot.
CircuitBreaker circuitBreaker = new CustomCircuitBreaker(11);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
trainedModelStatsService,
Settings.EMPTY,
"test-node",
circuitBreaker);
// Register the failure expectation BEFORE triggering the load so the listener
// is guaranteed to observe the CircuitBreakingException for model3.
modelLoadingService.addModelLoadedListener(model3, ActionListener.wrap(
r -> fail("Should not have succeeded to load model as breaker should be reached"),
e -> assertThat(e, instanceOf(CircuitBreakingException.class))
));
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
// Should have been loaded from the cluster change event but it is unknown in what order
// the loading occurred or which models are currently in the cache due to evictions.
// Verify that we have at least loaded all three
assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any());
});
// 5 + 5 bytes remain reserved; model3's rejected reservation was rolled back,
// and the breaker tripped exactly once.
assertBusy(() -> {
assertThat(circuitBreaker.getUsed(), equalTo(10L));
assertThat(circuitBreaker.getTrippedCount(), equalTo(1L));
});
// Shrinking the referenced set to model1 evicts model2, releasing its 5 bytes.
modelLoadingService.clusterChanged(ingestChangedEvent(model1));
assertBusy(() -> {
assertThat(circuitBreaker.getUsed(), equalTo(5L));
});
}
@SuppressWarnings("unchecked")
private void withTrainedModel(String modelId, long size) {
InferenceDefinition definition = mock(InferenceDefinition.class);
@ -378,15 +426,48 @@ public class ModelLoadingServiceTests extends ESTestCase {
when(trainedModelConfig.getModelId()).thenReturn(modelId);
when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(size);
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
listener.onResponse(Tuple.tuple(trainedModelConfig, definition));
listener.onResponse(definition);
return null;
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
}
@SuppressWarnings("unchecked")
private void withMissingModel(String modelId) {
if (randomBoolean()) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
} else {
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return null;
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
}
private void withMissingModel(String modelId) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
@ -438,6 +519,79 @@ public class ModelLoadingServiceTests extends ESTestCase {
}
}
/**
 * Minimal test-only {@link CircuitBreaker}: reserves bytes against a fixed limit and
 * throws a TRANSIENT {@link CircuitBreakingException} once the limit would be reached.
 *
 * Thread-safety: {@code currentBytes} and {@code trippedCount} are guarded by
 * {@code synchronized (this)}. Readers must synchronize too — tests poll
 * {@link #getUsed()} from another thread via assertBusy, and an unsynchronized read
 * of a mutable {@code long} has no visibility guarantee (and may tear on 32-bit VMs).
 */
private static class CustomCircuitBreaker implements CircuitBreaker {

    private final long maxBytes;
    // Guarded by synchronized (this) for both reads and writes.
    private long currentBytes = 0;
    private long trippedCount = 0;

    CustomCircuitBreaker(long maxBytes) {
        this.maxBytes = maxBytes;
    }

    @Override
    public void circuitBreak(String fieldName, long bytesNeeded) {
        throw new CircuitBreakingException(fieldName, Durability.TRANSIENT);
    }

    @Override
    public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
        synchronized (this) {
            // Check BEFORE accounting so a rejected reservation leaves usage unchanged.
            if (bytes + currentBytes >= maxBytes) {
                trippedCount++;
                circuitBreak(label, bytes);
            }
            currentBytes += bytes;
            return currentBytes;
        }
    }

    @Override
    public long addWithoutBreaking(long bytes) {
        // Also used with negative values to release previously reserved bytes.
        synchronized (this) {
            currentBytes += bytes;
            return currentBytes;
        }
    }

    @Override
    public long getUsed() {
        // Synchronized read: guarantees visibility of updates made under the same
        // lock by loader threads (the original unsynchronized read did not).
        synchronized (this) {
            return currentBytes;
        }
    }

    @Override
    public long getLimit() {
        // maxBytes is final, so no synchronization is required here.
        return maxBytes;
    }

    @Override
    public double getOverhead() {
        return 1.0;
    }

    @Override
    public long getTrippedCount() {
        synchronized (this) {
            return trippedCount;
        }
    }

    @Override
    public String getName() {
        return MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME;
    }

    @Override
    public Durability getDurability() {
        return Durability.TRANSIENT;
    }

    @Override
    public void setLimitAndOverhead(long limit, double overhead) {
        // This breaker is fixed-limit by design; tests never reconfigure it.
        throw new UnsupportedOperationException("boom");
    }
}
private static class ModelLoadedTracker {
private final Set<String> expectedModelIds;

View File

@ -60,6 +60,7 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
import org.elasticsearch.xpack.ilm.IndexLifecycle;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.monitoring.MonitoringService;
import org.junit.After;
import org.junit.Before;
@ -98,6 +99,8 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
settings.put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial");
settings.put(XPackSettings.WATCHER_ENABLED.getKey(), false);
settings.put(XPackSettings.GRAPH_ENABLED.getKey(), false);
settings.put(MonitoringService.ENABLED.getKey(), false);
settings.put(MonitoringService.ELASTICSEARCH_COLLECTION_ENABLED.getKey(), false);
settings.put(LifecycleSettings.LIFECYCLE_HISTORY_INDEX_ENABLED_SETTING.getKey(), false);
return settings.build();
}