This adds new plugin level circuit breaker for the ML plugin. `model_inference` is the circuit breaker qualified name. Right now it simply adds to the breaker when the model is loaded (and possibly breaking) and removing from the breaker when the model is unloaded.
This commit is contained in:
parent
6e8cf0973f
commit
d5522c2747
|
@ -56,6 +56,7 @@ The maximum inference cache size allowed. The inference cache exists in the JVM
|
|||
heap on each ingest node. The cache affords faster processing times for the
|
||||
`inference` processor. The value can be a static byte sized value (i.e. "2gb")
|
||||
or a percentage of total allocated heap. The default is "40%".
|
||||
See also <<model-inference-circuit-breaker>>.
|
||||
|
||||
`xpack.ml.inference_model.time_to_live`::
|
||||
The time to live (TTL) for models in the inference model cache. The TTL is
|
||||
|
@ -137,3 +138,24 @@ to the {es} JVM. When such processes are started they must connect to the {es}
|
|||
JVM. If such a process does not connect within the time period specified by this
|
||||
setting then the process is assumed to have failed. Defaults to `10s`. The minimum
|
||||
value for this setting is `5s`.
|
||||
|
||||
[[model-inference-circuit-breaker]]
|
||||
==== {ml-cap} circuit breaker settings
|
||||
|
||||
`breaker.model_inference.limit` (<<cluster-update-settings,Dynamic>>)
|
||||
Limit for model inference breaker, defaults to 50% of JVM heap.
|
||||
If the parent circuit breaker is less than 50% of JVM heap, it is bound
|
||||
to that limit instead.
|
||||
See <<circuit-breaker>>.
|
||||
|
||||
`breaker.model_inference.overhead` (<<cluster-update-settings,Dynamic>>)
|
||||
A constant that all accounting estimations are multiplied with to determine
|
||||
a final estimation. Defaults to 1.
|
||||
See <<circuit-breaker>>.
|
||||
|
||||
`breaker.model_inference.type`
|
||||
The underlying type of the circuit breaker. There are two valid options:
|
||||
`noop`, meaning the circuit breaker does nothing to prevent too much memory usage,
|
||||
`memory`, meaning the circuit breaker tracks the memory used by inference models and
|
||||
could potentially break and prevent OutOfMemory errors.
|
||||
The default is `memory`.
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole;
|
|||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.inject.Module;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.settings.ClusterSettings;
|
||||
import org.elasticsearch.common.settings.IndexScopedSettings;
|
||||
|
@ -40,12 +41,15 @@ import org.elasticsearch.env.NodeEnvironment;
|
|||
import org.elasticsearch.index.analysis.TokenizerFactory;
|
||||
import org.elasticsearch.indices.SystemIndexDescriptor;
|
||||
import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
|
||||
import org.elasticsearch.indices.breaker.BreakerSettings;
|
||||
import org.elasticsearch.ingest.Processor;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.monitor.jvm.JvmInfo;
|
||||
import org.elasticsearch.monitor.os.OsProbe;
|
||||
import org.elasticsearch.monitor.os.OsStats;
|
||||
import org.elasticsearch.persistent.PersistentTasksExecutor;
|
||||
import org.elasticsearch.plugins.AnalysisPlugin;
|
||||
import org.elasticsearch.plugins.CircuitBreakerPlugin;
|
||||
import org.elasticsearch.plugins.IngestPlugin;
|
||||
import org.elasticsearch.plugins.PersistentTaskPlugin;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
|
@ -323,7 +327,11 @@ import java.util.function.UnaryOperator;
|
|||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
public class MachineLearning extends Plugin implements SystemIndexPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
|
||||
public class MachineLearning extends Plugin implements SystemIndexPlugin,
|
||||
AnalysisPlugin,
|
||||
CircuitBreakerPlugin,
|
||||
IngestPlugin,
|
||||
PersistentTaskPlugin {
|
||||
public static final String NAME = "ml";
|
||||
public static final String BASE_PATH = "/_ml/";
|
||||
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
|
||||
|
@ -331,6 +339,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
|
|||
public static final String JOB_COMMS_THREAD_POOL_NAME = NAME + "_job_comms";
|
||||
public static final String UTILITY_THREAD_POOL_NAME = NAME + "_utility";
|
||||
|
||||
public static final String TRAINED_MODEL_CIRCUIT_BREAKER_NAME = "model_inference";
|
||||
|
||||
private static final long DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT = (long)((0.50) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes());
|
||||
private static final double DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD = 1.0D;
|
||||
// This is for performance testing. It's not exposed to the end user.
|
||||
// Recompile if you want to compare performance with C++ tokenization.
|
||||
public static final boolean CATEGORIZATION_TOKENIZATION_IN_JAVA = true;
|
||||
|
@ -436,6 +448,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
|
|||
private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
|
||||
private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
|
||||
private final SetOnce<ActionFilter> mlUpgradeModeActionFilter = new SetOnce<>();
|
||||
private final SetOnce<CircuitBreaker> inferenceModelBreaker = new SetOnce<>();
|
||||
|
||||
public MachineLearning(Settings settings, Path configPath) {
|
||||
this.settings = settings;
|
||||
|
@ -661,10 +674,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
|
|||
inferenceAuditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
xContentRegistry,
|
||||
trainedModelStatsService,
|
||||
settings,
|
||||
clusterService.getNodeName());
|
||||
clusterService.getNodeName(),
|
||||
inferenceModelBreaker.get());
|
||||
|
||||
// Data frame analytics components
|
||||
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
|
||||
|
@ -1001,4 +1014,23 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
|
|||
new SystemIndexDescriptor(InferenceIndexConstants.INDEX_PATTERN, "Contains ML model configuration and statistics")
|
||||
));
|
||||
}
|
||||
|
||||
@Override
|
||||
public BreakerSettings getCircuitBreaker(Settings settings) {
|
||||
return BreakerSettings.updateFromSettings(
|
||||
new BreakerSettings(
|
||||
TRAINED_MODEL_CIRCUIT_BREAKER_NAME,
|
||||
DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT,
|
||||
DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD,
|
||||
CircuitBreaker.Type.MEMORY,
|
||||
CircuitBreaker.Durability.TRANSIENT
|
||||
),
|
||||
settings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCircuitBreaker(CircuitBreaker circuitBreaker) {
|
||||
assert circuitBreaker.getName().equals(TRAINED_MODEL_CIRCUIT_BREAKER_NAME);
|
||||
this.inferenceModelBreaker.set(circuitBreaker);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@ import org.elasticsearch.cluster.ClusterChangedEvent;
|
|||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.ClusterStateListener;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.common.breaker.CircuitBreakingException;
|
||||
import org.elasticsearch.common.cache.Cache;
|
||||
import org.elasticsearch.common.cache.CacheBuilder;
|
||||
import org.elasticsearch.common.cache.RemovalNotification;
|
||||
|
@ -24,7 +26,6 @@ import org.elasticsearch.common.settings.Settings;
|
|||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.util.set.Sets;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.ingest.IngestMetadata;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||
|
@ -95,15 +96,16 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
private final InferenceAuditor auditor;
|
||||
private final ByteSizeValue maxCacheSize;
|
||||
private final String localNode;
|
||||
private final CircuitBreaker trainedModelCircuitBreaker;
|
||||
|
||||
public ModelLoadingService(TrainedModelProvider trainedModelProvider,
|
||||
InferenceAuditor auditor,
|
||||
ThreadPool threadPool,
|
||||
ClusterService clusterService,
|
||||
NamedXContentRegistry namedXContentRegistry,
|
||||
TrainedModelStatsService modelStatsService,
|
||||
Settings settings,
|
||||
String localNode) {
|
||||
String localNode,
|
||||
CircuitBreaker trainedModelCircuitBreaker) {
|
||||
this.provider = trainedModelProvider;
|
||||
this.threadPool = threadPool;
|
||||
this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
|
||||
|
@ -119,6 +121,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
.build();
|
||||
clusterService.addListener(this);
|
||||
this.localNode = localNode;
|
||||
this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -149,13 +152,17 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
// If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
|
||||
// by a simulated pipeline
|
||||
logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
|
||||
provider.getTrainedModel(modelId, false, ActionListener.wrap(
|
||||
trainedModelConfig -> {
|
||||
// Verify we can pull the model into memory without causing OOM
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
|
||||
configAndInferenceDef -> {
|
||||
TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1();
|
||||
InferenceDefinition inferenceDefinition = configAndInferenceDef.v2();
|
||||
inferenceDefinition -> {
|
||||
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
|
||||
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
|
||||
trainedModelConfig.getInferenceConfig();
|
||||
// Remove the bytes as we cannot control how long the caller will keep the model in memory
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
|
||||
modelActionListener.onResponse(new LocalModel(
|
||||
trainedModelConfig.getModelId(),
|
||||
localNode,
|
||||
|
@ -165,6 +172,13 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
inferenceConfig,
|
||||
modelStatsService));
|
||||
},
|
||||
// Failure getting the definition, remove the initial estimation value
|
||||
e -> {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
|
||||
modelActionListener.onFailure(e);
|
||||
}
|
||||
));
|
||||
},
|
||||
modelActionListener::onFailure
|
||||
));
|
||||
} else {
|
||||
|
@ -205,29 +219,53 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
}
|
||||
|
||||
private void loadModel(String modelId) {
|
||||
provider.getTrainedModel(modelId, false, ActionListener.wrap(
|
||||
trainedModelConfig -> {
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
|
||||
provider.getTrainedModelForInference(modelId, ActionListener.wrap(
|
||||
configAndInferenceDef -> {
|
||||
logger.debug(() -> new ParameterizedMessage("[{}] successfully loaded model", modelId));
|
||||
handleLoadSuccess(modelId, configAndInferenceDef);
|
||||
inferenceDefinition -> {
|
||||
// Since we have used the previously stored estimate to help guard against OOM we need to adjust the memory
|
||||
// So that the memory this model uses in the circuit breaker is the most accurate estimate.
|
||||
long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
|
||||
if (estimateDiff < 0) {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
|
||||
} else if (estimateDiff > 0) { // rare case where estimate is now HIGHER
|
||||
try {
|
||||
trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
|
||||
} catch (CircuitBreakingException ex) { // if we failed here, we should remove the initial estimate as well
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
|
||||
handleLoadFailure(modelId, ex);
|
||||
return;
|
||||
}
|
||||
}
|
||||
handleLoadSuccess(modelId, trainedModelConfig, inferenceDefinition);
|
||||
},
|
||||
failure -> {
|
||||
logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure);
|
||||
// We failed to get the definition, remove the initial estimation.
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
|
||||
logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure);
|
||||
handleLoadFailure(modelId, failure);
|
||||
}
|
||||
));
|
||||
},
|
||||
failure -> {
|
||||
logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure);
|
||||
handleLoadFailure(modelId, failure);
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
private void handleLoadSuccess(String modelId,
|
||||
Tuple<TrainedModelConfig, InferenceDefinition> configAndInferenceDef) {
|
||||
TrainedModelConfig trainedModelConfig,
|
||||
InferenceDefinition inferenceDefinition) {
|
||||
Queue<ActionListener<Model>> listeners;
|
||||
TrainedModelConfig trainedModelConfig = configAndInferenceDef.v1();
|
||||
InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ?
|
||||
inferenceConfigFromTargetType(trainedModelConfig.getModelDefinition().getTrainedModel().targetType()) :
|
||||
inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
|
||||
trainedModelConfig.getInferenceConfig();
|
||||
LocalModel loadedModel = new LocalModel(
|
||||
trainedModelConfig.getModelId(),
|
||||
localNode,
|
||||
configAndInferenceDef.v2(),
|
||||
inferenceDefinition,
|
||||
trainedModelConfig.getInput(),
|
||||
trainedModelConfig.getDefaultFieldMap(),
|
||||
inferenceConfig,
|
||||
|
@ -237,6 +275,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
|
||||
// Consequently, we should not store the retrieved model
|
||||
if (listeners == null) {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-inferenceDefinition.ramBytesUsed());
|
||||
return;
|
||||
}
|
||||
localModelCache.put(modelId, loadedModel);
|
||||
|
@ -263,6 +302,7 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
}
|
||||
|
||||
private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
|
||||
try {
|
||||
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
|
||||
MessageSupplier msg = () -> new ParameterizedMessage(
|
||||
"model cache entry evicted." +
|
||||
|
@ -277,6 +317,9 @@ public class ModelLoadingService implements ClusterStateListener {
|
|||
}
|
||||
// If the model is no longer referenced, flush the stats to persist as soon as possible
|
||||
notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false);
|
||||
} finally {
|
||||
trainedModelCircuitBreaker.addWithoutBreaking(-notification.getValue().ramBytesUsed());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -232,19 +232,17 @@ public class TrainedModelProvider {
|
|||
executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
|
||||
}
|
||||
|
||||
public void getTrainedModelForInference(final String modelId,
|
||||
final ActionListener<Tuple<TrainedModelConfig, InferenceDefinition>> listener) {
|
||||
public void getTrainedModelForInference(final String modelId, final ActionListener<InferenceDefinition> listener) {
|
||||
// TODO Change this when we get more than just langIdent stored
|
||||
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
||||
try {
|
||||
TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
|
||||
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
|
||||
listener.onResponse(Tuple.tuple(
|
||||
config,
|
||||
listener.onResponse(
|
||||
InferenceDefinition.builder()
|
||||
.setPreProcessors(config.getModelDefinition().getPreProcessors())
|
||||
.setTrainedModel((LangIdentNeuralNetwork)config.getModelDefinition().getTrainedModel())
|
||||
.build()));
|
||||
.build());
|
||||
return;
|
||||
} catch (ElasticsearchException|IOException ex) {
|
||||
listener.onFailure(ex);
|
||||
|
@ -252,8 +250,6 @@ public class TrainedModelProvider {
|
|||
}
|
||||
}
|
||||
|
||||
getTrainedModel(modelId, false, ActionListener.wrap(
|
||||
config -> {
|
||||
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
||||
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
|
||||
.boolQuery()
|
||||
|
@ -270,6 +266,11 @@ public class TrainedModelProvider {
|
|||
.request();
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
if (searchResponse.getHits().getHits().length == 0) {
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||
return;
|
||||
}
|
||||
List<TrainedModelDefinitionDoc> docs = handleHits(searchResponse.getHits().getHits(),
|
||||
modelId,
|
||||
this::parseModelDefinitionDocLenientlyFromSource);
|
||||
|
@ -285,13 +286,16 @@ public class TrainedModelProvider {
|
|||
compressedString,
|
||||
InferenceDefinition::fromXContent,
|
||||
xContentRegistry);
|
||||
listener.onResponse(Tuple.tuple(config, inferenceDefinition));
|
||||
listener.onResponse(inferenceDefinition);
|
||||
},
|
||||
listener::onFailure
|
||||
));
|
||||
|
||||
},
|
||||
listener::onFailure
|
||||
e -> {
|
||||
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||
return;
|
||||
}
|
||||
listener.onFailure(e);
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionRequest;
|
|||
import org.elasticsearch.action.ActionResponse;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.TransportAction;
|
||||
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.license.LicenseService;
|
||||
|
@ -29,18 +30,22 @@ import java.util.Collections;
|
|||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.ml.MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME;
|
||||
|
||||
public class LocalStateMachineLearning extends LocalStateCompositeXPackPlugin {
|
||||
|
||||
public LocalStateMachineLearning(final Settings settings, final Path configPath) throws Exception {
|
||||
super(settings, configPath);
|
||||
LocalStateMachineLearning thisVar = this;
|
||||
MachineLearning plugin = new MachineLearning(settings, configPath){
|
||||
|
||||
plugins.add(new MachineLearning(settings, configPath){
|
||||
@Override
|
||||
protected XPackLicenseState getLicenseState() {
|
||||
return thisVar.getLicenseState();
|
||||
}
|
||||
});
|
||||
};
|
||||
plugin.setCircuitBreaker(new NoopCircuitBreaker(TRAINED_MODEL_CIRCUIT_BREAKER_NAME));
|
||||
plugins.add(plugin);
|
||||
plugins.add(new Monitoring(settings) {
|
||||
@Override
|
||||
protected SSLService getSslService() {
|
||||
|
|
|
@ -18,13 +18,13 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
|
|||
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
|
||||
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.breaker.CircuitBreaker;
|
||||
import org.elasticsearch.common.breaker.CircuitBreakingException;
|
||||
import org.elasticsearch.common.bytes.BytesReference;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.transport.TransportAddress;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.common.xcontent.XContentType;
|
||||
|
@ -40,6 +40,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
|
|||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
|
||||
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
|
@ -61,6 +62,7 @@ import java.util.concurrent.TimeUnit;
|
|||
|
||||
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
|
@ -83,6 +85,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
private ClusterService clusterService;
|
||||
private InferenceAuditor auditor;
|
||||
private TrainedModelStatsService trainedModelStatsService;
|
||||
private CircuitBreaker circuitBreaker;
|
||||
|
||||
@Before
|
||||
public void setUpComponents() {
|
||||
|
@ -97,6 +100,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class));
|
||||
doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class));
|
||||
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build());
|
||||
circuitBreaker = new CustomCircuitBreaker(1000);
|
||||
}
|
||||
|
||||
@After
|
||||
|
@ -116,10 +120,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
NamedXContentRegistry.EMPTY,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node");
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
||||
|
||||
|
@ -163,10 +167,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
NamedXContentRegistry.EMPTY,
|
||||
trainedModelStatsService,
|
||||
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build(),
|
||||
"test-node");
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
// We want to be notified when the models are loaded which happens in a background thread
|
||||
ModelLoadedTracker loadedTracker = new ModelLoadedTracker(Arrays.asList(modelIds));
|
||||
|
@ -279,10 +283,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
NamedXContentRegistry.EMPTY,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node");
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));
|
||||
|
||||
|
@ -304,10 +308,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
NamedXContentRegistry.EMPTY,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node");
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(model));
|
||||
|
||||
PlainActionFuture<Model> future = new PlainActionFuture<>();
|
||||
|
@ -332,10 +336,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
NamedXContentRegistry.EMPTY,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node");
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
PlainActionFuture<Model> future = new PlainActionFuture<>();
|
||||
modelLoadingService.getModel(model, future);
|
||||
|
@ -355,10 +359,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
NamedXContentRegistry.EMPTY,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node");
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
for(int i = 0; i < 3; i++) {
|
||||
PlainActionFuture<Model> future = new PlainActionFuture<>();
|
||||
|
@ -370,6 +374,50 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
}
|
||||
|
||||
public void testCircuitBreakerBreak() throws Exception {
|
||||
String model1 = "test-circuit-break-model-1";
|
||||
String model2 = "test-circuit-break-model-2";
|
||||
String model3 = "test-circuit-break-model-3";
|
||||
withTrainedModel(model1, 5L);
|
||||
withTrainedModel(model2, 5L);
|
||||
withTrainedModel(model3, 12L);
|
||||
CircuitBreaker circuitBreaker = new CustomCircuitBreaker(11);
|
||||
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
|
||||
auditor,
|
||||
threadPool,
|
||||
clusterService,
|
||||
trainedModelStatsService,
|
||||
Settings.EMPTY,
|
||||
"test-node",
|
||||
circuitBreaker);
|
||||
|
||||
modelLoadingService.addModelLoadedListener(model3, ActionListener.wrap(
|
||||
r -> fail("Should not have succeeded to load model as breaker should be reached"),
|
||||
e -> assertThat(e, instanceOf(CircuitBreakingException.class))
|
||||
));
|
||||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
|
||||
|
||||
// Should have been loaded from the cluster change event but it is unknown in what order
|
||||
// the loading occurred or which models are currently in the cache due to evictions.
|
||||
// Verify that we have at least loaded all three
|
||||
assertBusy(() -> {
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any());
|
||||
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any());
|
||||
});
|
||||
assertBusy(() -> {
|
||||
assertThat(circuitBreaker.getUsed(), equalTo(10L));
|
||||
assertThat(circuitBreaker.getTrippedCount(), equalTo(1L));
|
||||
});
|
||||
|
||||
modelLoadingService.clusterChanged(ingestChangedEvent(model1));
|
||||
|
||||
assertBusy(() -> {
|
||||
assertThat(circuitBreaker.getUsed(), equalTo(5L));
|
||||
});
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void withTrainedModel(String modelId, long size) {
|
||||
InferenceDefinition definition = mock(InferenceDefinition.class);
|
||||
|
@ -378,15 +426,48 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
when(trainedModelConfig.getModelId()).thenReturn(modelId);
|
||||
when(trainedModelConfig.getInferenceConfig()).thenReturn(ClassificationConfig.EMPTY_PARAMS);
|
||||
when(trainedModelConfig.getInput()).thenReturn(new TrainedModelInput(Arrays.asList("foo", "bar", "baz")));
|
||||
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(size);
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
||||
listener.onResponse(Tuple.tuple(trainedModelConfig, definition));
|
||||
listener.onResponse(definition);
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
||||
listener.onResponse(trainedModelConfig);
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void withMissingModel(String modelId) {
|
||||
if (randomBoolean()) {
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
||||
} else {
|
||||
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
|
||||
when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
|
||||
listener.onResponse(trainedModelConfig);
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
||||
listener.onFailure(new ResourceNotFoundException(
|
||||
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
||||
return null;
|
||||
}).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
|
||||
}
|
||||
|
||||
private void withMissingModel(String modelId) {
|
||||
doAnswer(invocationOnMock -> {
|
||||
@SuppressWarnings("rawtypes")
|
||||
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
||||
|
@ -438,6 +519,79 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private static class CustomCircuitBreaker implements CircuitBreaker {
|
||||
|
||||
private final long maxBytes;
|
||||
private long currentBytes = 0;
|
||||
private long trippedCount = 0;
|
||||
|
||||
CustomCircuitBreaker(long maxBytes) {
|
||||
this.maxBytes = maxBytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void circuitBreak(String fieldName, long bytesNeeded) {
|
||||
throw new CircuitBreakingException(fieldName, Durability.TRANSIENT);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
|
||||
synchronized (this) {
|
||||
if (bytes + currentBytes >= maxBytes) {
|
||||
trippedCount++;
|
||||
circuitBreak(label, bytes);
|
||||
}
|
||||
currentBytes += bytes;
|
||||
return currentBytes;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long addWithoutBreaking(long bytes) {
|
||||
synchronized (this) {
|
||||
currentBytes += bytes;
|
||||
return currentBytes;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getUsed() {
|
||||
return currentBytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getLimit() {
|
||||
return maxBytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getOverhead() {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTrippedCount() {
|
||||
synchronized (this) {
|
||||
return trippedCount;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Durability getDurability() {
|
||||
return Durability.TRANSIENT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLimitAndOverhead(long limit, double overhead) {
|
||||
throw new UnsupportedOperationException("boom");
|
||||
}
|
||||
}
|
||||
|
||||
private static class ModelLoadedTracker {
|
||||
private final Set<String> expectedModelIds;
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
|
|||
import org.elasticsearch.xpack.ilm.IndexLifecycle;
|
||||
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.monitoring.MonitoringService;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -98,6 +99,8 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
|
|||
settings.put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial");
|
||||
settings.put(XPackSettings.WATCHER_ENABLED.getKey(), false);
|
||||
settings.put(XPackSettings.GRAPH_ENABLED.getKey(), false);
|
||||
settings.put(MonitoringService.ENABLED.getKey(), false);
|
||||
settings.put(MonitoringService.ELASTICSEARCH_COLLECTION_ENABLED.getKey(), false);
|
||||
settings.put(LifecycleSettings.LIFECYCLE_HISTORY_INDEX_ENABLED_SETTING.getKey(), false);
|
||||
return settings.build();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue