From cabff65aecdce39df6628b6f5f929e34112d79c9 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 20 Apr 2020 16:21:18 -0400 Subject: [PATCH] [ML] Fixing inference stats race condition (#55163) (#55486) `updateAndGet` could actually call the internal method more than once on contention. If I read the JavaDocs, it says: ```* @param updateFunction a side-effect-free function``` So, it could be getting multiple updates on contention, thus having a race condition where stats are double counted. To fix, I am going to use a `ReadWriteLock`. The `LongAdder` objects allows fast thread safe writes in high contention environments. These can be protected by the `ReadWriteLock::readLock`. When stats are persisted, I need to call reset on all these adders. This is NOT thread safe if additions are taking place concurrently. So, I am going to protect with `ReadWriteLock::writeLock`. This should prevent race conditions while allowing high (ish) throughput in the highly contention paths in inference. I did some simple throughput tests and this change is not significantly slower and is simpler to grok (IMO). closes https://github.com/elastic/elasticsearch/issues/54786 --- .../trainedmodel/InferenceStats.java | 54 ++++++- .../ml/integration/InferenceIngestIT.java | 142 +++++++++++------- .../inference/TrainedModelStatsService.java | 99 +++++++++--- .../inference/loadingservice/LocalModel.java | 22 ++- .../loadingservice/ModelLoadingService.java | 3 +- .../RestGetTrainedModelsStatsAction.java | 2 +- .../loadingservice/LocalModelTests.java | 13 +- .../ModelLoadingServiceTests.java | 15 +- 8 files changed, 242 insertions(+), 108 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java index 9a34f21fe42..d14320ed053 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java @@ -21,6 +21,8 @@ import java.io.IOException; import java.time.Instant; import java.util.Objects; import java.util.concurrent.atomic.LongAdder; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; public class InferenceStats implements ToXContentObject, Writeable { @@ -204,6 +206,12 @@ public class InferenceStats implements ToXContentObject, Writeable { private final LongAdder failureCountAccumulator = new LongAdder(); private final String modelId; private final String nodeId; + // curious reader + // you may be wondering why the lock set to the fair. + // When `currentStatsAndReset` is called, we want it guaranteed that it will eventually execute. + // If a ReadWriteLock is unfair, there are no such guarantees. + // A call for the `writelock::lock` could pause indefinitely. + private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true); public Accumulator(String modelId, String nodeId) { this.modelId = modelId; @@ -226,22 +234,52 @@ public class InferenceStats implements ToXContentObject, Writeable { } public Accumulator incMissingFields() { - this.missingFieldsAccumulator.increment(); - return this; + readWriteLock.readLock().lock(); + try { + this.missingFieldsAccumulator.increment(); + return this; + } finally { + readWriteLock.readLock().unlock(); + } } public Accumulator incInference() { - this.inferenceAccumulator.increment(); - return this; + readWriteLock.readLock().lock(); + try { + this.inferenceAccumulator.increment(); + return this; + } finally { + readWriteLock.readLock().unlock(); + } } public Accumulator incFailure() { - this.failureCountAccumulator.increment(); - return this; + readWriteLock.readLock().lock(); + try { + this.failureCountAccumulator.increment(); + return this; + } finally { + readWriteLock.readLock().unlock(); + } } - public InferenceStats currentStats() { - return currentStats(Instant.now()); + /** + * Thread safe. + * + * Returns the current stats and resets the values of all the counters. + * @return The current stats + */ + public InferenceStats currentStatsAndReset() { + readWriteLock.writeLock().lock(); + try { + InferenceStats stats = currentStats(Instant.now()); + this.missingFieldsAccumulator.reset(); + this.inferenceAccumulator.reset(); + this.failureCountAccumulator.reset(); + return stats; + } finally { + readWriteLock.writeLock().unlock(); + } } public InferenceStats currentStats(Instant timeStamp) { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index eed952cd940..76ea67d4ad4 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -22,7 +22,9 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.test.ExternalTestCluster; import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; import org.junit.After; import org.junit.Before; @@ -46,14 +48,16 @@ public class InferenceIngestIT extends ESRestTestCase { basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING); @Before - public void createBothModels() throws Exception { - Request request = new Request("PUT", "_ml/inference/test_classification"); - request.setJsonEntity(CLASSIFICATION_CONFIG); - client().performRequest(request); - - request = new Request("PUT", "_ml/inference/test_regression"); - request.setJsonEntity(REGRESSION_CONFIG); - client().performRequest(request); + public void setup() throws Exception { + Request loggingSettings = new Request("PUT", "_cluster/settings"); + loggingSettings.setJsonEntity("" + + "{" + + "\"transient\" : {\n" + + " \"logger.org.elasticsearch.xpack.ml.inference\" : \"TRACE\"\n" + + " }" + + "}"); + client().performRequest(loggingSettings); + client().performRequest(new Request("GET", "/_cluster/health?wait_for_status=green&timeout=30s")); } @Override @@ -64,19 +68,33 @@ public class InferenceIngestIT extends ESRestTestCase { @After public void cleanUpData() throws Exception { new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata(); + client().performRequest(new Request("DELETE", InferenceIndexConstants.INDEX_PATTERN)); + client().performRequest(new Request("DELETE", MlStatsIndex.indexPattern())); + Request loggingSettings = new Request("PUT", "_cluster/settings"); + loggingSettings.setJsonEntity("" + + "{" + + "\"transient\" : {\n" + + " \"logger.org.elasticsearch.xpack.ml.inference\" : null\n" + + " }" + + "}"); + client().performRequest(loggingSettings); ESRestTestCase.waitForPendingTasks(adminClient()); - client().performRequest(new Request("DELETE", "_ml/inference/test_classification")); - client().performRequest(new Request("DELETE", "_ml/inference/test_regression")); } public void testPathologicalPipelineCreationAndDeletion() throws Exception { + String classificationModelId = "test_pathological_classification"; + putModel(classificationModelId, CLASSIFICATION_CONFIG); + + String regressionModelId = "test_pathological_regression"; + putModel(regressionModelId, REGRESSION_CONFIG); for (int i = 0; i < 10; i++) { - client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE)); + client().performRequest(putPipeline("simple_classification_pipeline", + pipelineDefinition(classificationModelId, "classification"))); client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc())); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline")); - client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE)); + client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression"))); client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline")); } @@ -94,13 +112,30 @@ public class InferenceIngestIT extends ESRestTestCase { QueryBuilders.existsQuery("ml.inference.classification.predicted_value")))); assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10")); + assertBusy(() -> { + try { + Response statsResponse = client().performRequest(new Request("GET", + "_ml/inference/" + classificationModelId + "/_stats")); + assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10")); + statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats")); + assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10")); + } catch (ResponseException ex) { + //this could just mean shard failures. + fail(ex.getMessage()); + } + }, 30, TimeUnit.SECONDS); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786") public void testPipelineIngest() throws Exception { + String classificationModelId = "test_classification"; + putModel(classificationModelId, CLASSIFICATION_CONFIG); - client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE)); - client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE)); + String regressionModelId = "test_regression"; + putModel(regressionModelId, REGRESSION_CONFIG); + + client().performRequest(putPipeline("simple_classification_pipeline", + pipelineDefinition(classificationModelId, "classification"))); + client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression"))); for (int i = 0; i < 10; i++) { client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc())); @@ -131,21 +166,30 @@ public class InferenceIngestIT extends ESRestTestCase { assertBusy(() -> { try { - Response statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_classification/_stats")); + Response statsResponse = client().performRequest(new Request("GET", + "_ml/inference/" + classificationModelId + "/_stats")); assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10")); - statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_regression/_stats")); + statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats")); assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15")); // can get both statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats")); - assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15")); - assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10")); + String entityString = EntityUtils.toString(statsResponse.getEntity()); + assertThat(entityString, containsString("\"inference_count\":15")); + assertThat(entityString, containsString("\"inference_count\":10")); } catch (ResponseException ex) { //this could just mean shard failures. + fail(ex.getMessage()); } }, 30, TimeUnit.SECONDS); } public void testSimulate() throws IOException { + String classificationModelId = "test_classification_simulate"; + putModel(classificationModelId, CLASSIFICATION_CONFIG); + + String regressionModelId = "test_regression_simulate"; + putModel(regressionModelId, REGRESSION_CONFIG); + String source = "{\n" + " \"pipeline\": {\n" + " \"processors\": [\n" + @@ -157,7 +201,7 @@ public class InferenceIngestIT extends ESRestTestCase { " \"top_classes_results_field\": \"result_class_prob\"," + " \"num_top_feature_importance_values\": 2" + " }},\n" + - " \"model_id\": \"test_classification\",\n" + + " \"model_id\": \"" + classificationModelId + "\",\n" + " \"field_map\": {\n" + " \"col1\": \"col1\",\n" + " \"col2\": \"col2\",\n" + @@ -169,7 +213,7 @@ public class InferenceIngestIT extends ESRestTestCase { " {\n" + " \"inference\": {\n" + " \"target_field\": \"ml.regression\",\n" + - " \"model_id\": \"test_regression\",\n" + + " \"model_id\": \"" + regressionModelId + "\",\n" + " \"inference_config\": {\"regression\":{}},\n" + " \"field_map\": {\n" + " \"col1\": \"col1\",\n" + @@ -232,6 +276,8 @@ public class InferenceIngestIT extends ESRestTestCase { } public void testSimulateWithDefaultMappedField() throws IOException { + String classificationModelId = "test_classification_default_mapped_field"; + putModel(classificationModelId, CLASSIFICATION_CONFIG); String source = "{\n" + " \"pipeline\": {\n" + " \"processors\": [\n" + @@ -243,7 +289,7 @@ public class InferenceIngestIT extends ESRestTestCase { " \"top_classes_results_field\": \"result_class_prob\"," + " \"num_top_feature_importance_values\": 2" + " }},\n" + - " \"model_id\": \"test_classification\",\n" + + " \"model_id\": \"" + classificationModelId + "\",\n" + " \"field_map\": {}\n" + " }\n" + " }\n"+ @@ -607,36 +653,28 @@ public class InferenceIngestIT extends ESRestTestCase { " \"definition\": " + CLASSIFICATION_DEFINITION + "}"; - private static final String CLASSIFICATION_PIPELINE = "{" + - " \"processors\": [\n" + - " {\n" + - " \"inference\": {\n" + - " \"model_id\": \"test_classification\",\n" + - " \"tag\": \"classification\",\n" + - " \"inference_config\": {\"classification\": {}},\n" + - " \"field_map\": {\n" + - " \"col1\": \"col1\",\n" + - " \"col2\": \"col2\",\n" + - " \"col3\": \"col3\",\n" + - " \"col4\": \"col4\"\n" + - " }\n" + - " }\n" + - " }]}\n"; + private static String pipelineDefinition(String modelId, String inferenceConfig) { + return "{" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"model_id\": \"" + modelId + "\",\n" + + " \"tag\": \""+ inferenceConfig + "\",\n" + + " \"inference_config\": {\"" + inferenceConfig + "\": {}},\n" + + " \"field_map\": {\n" + + " \"col1\": \"col1\",\n" + + " \"col2\": \"col2\",\n" + + " \"col3\": \"col3\",\n" + + " \"col4\": \"col4\"\n" + + " }\n" + + " }\n" + + " }]}\n"; + } - private static final String REGRESSION_PIPELINE = "{" + - " \"processors\": [\n" + - " {\n" + - " \"inference\": {\n" + - " \"model_id\": \"test_regression\",\n" + - " \"tag\": \"regression\",\n" + - " \"inference_config\": {\"regression\": {}},\n" + - " \"field_map\": {\n" + - " \"col1\": \"col1\",\n" + - " \"col2\": \"col2\",\n" + - " \"col3\": \"col3\",\n" + - " \"col4\": \"col4\"\n" + - " }\n" + - " }\n" + - " }]}\n"; + private void putModel(String modelId, String modelConfiguration) throws IOException { + Request request = new Request("PUT", "_ml/inference/" + modelId); + request.setJsonEntity(modelConfiguration); + client().performRequest(request); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java index 7a6cc7974c3..b255ff8978d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/TrainedModelStatsService.java @@ -9,18 +9,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.indices.InvalidAliasNameException; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptType; import org.elasticsearch.threadpool.Scheduler; @@ -28,6 +31,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; @@ -36,9 +40,11 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -68,7 +74,6 @@ public class TrainedModelStatsService { private final IndexNameExpressionResolver indexNameExpressionResolver; private final ThreadPool threadPool; private volatile Scheduler.Cancellable scheduledFuture; - private volatile boolean verifiedStatsIndexCreated; private volatile boolean stopped; private volatile ClusterState clusterState; @@ -97,15 +102,26 @@ public class TrainedModelStatsService { clusterService.addListener((event) -> this.clusterState = event.state()); } - public void queueStats(InferenceStats stats) { - statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()), - (k, previousStats) -> previousStats == null ? - stats : - InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp())); + /** + * Queues the stats for storing. + * @param stats The stats to store or increment + * @param flush When `true`, this indicates that stats should be written as soon as possible. + * If `false`, stats are not persisted until the next periodic persistence action. + */ + public void queueStats(InferenceStats stats, boolean flush) { + if (stats.hasStats()) { + statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()), + (k, previousStats) -> previousStats == null ? + stats : + InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp())); + } + if (flush) { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(this::updateStats); + } } void stop() { - logger.info("About to stop TrainedModelStatsService"); + logger.debug("About to stop TrainedModelStatsService"); stopped = true; statsQueue.clear(); @@ -116,7 +132,7 @@ public class TrainedModelStatsService { } void start() { - logger.info("About to start TrainedModelStatsService"); + logger.debug("About to start TrainedModelStatsService"); stopped = false; scheduledFuture = threadPool.scheduleWithFixedDelay(this::updateStats, PERSISTENCE_INTERVAL, @@ -124,27 +140,31 @@ public class TrainedModelStatsService { } void updateStats() { - if (clusterState == null || statsQueue.isEmpty()) { + if (clusterState == null || statsQueue.isEmpty() || stopped) { return; } - if (verifiedStatsIndexCreated == false) { - logger.info("About to create the stats index as it does not exist yet"); + if (verifyIndicesPrimaryShardsAreActive(clusterState, indexNameExpressionResolver) == false) { try { - PlainActionFuture listener = new PlainActionFuture<>(); - MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener); - listener.actionGet(); - verifiedStatsIndexCreated = true; - logger.info("Created stats index"); - } catch (Exception e) { - logger.error("failure creating ml stats index for storing model stats", e); - return; + logger.debug("About to create the stats index as it does not exist yet"); + createStatsIndexIfNecessary(); + } catch(Exception e){ + // This exception occurs if, for some reason, the `createStatsIndexAndAliasIfNecessary` fails due to + // a concrete index of the alias name already existing. This error is recoverable eventually, but + // should NOT cause us to lose statistics. + if ((e instanceof InvalidAliasNameException) == false) { + logger.error("failure creating ml stats index for storing model stats", e); + return; + } } } List stats = new ArrayList<>(statsQueue.size()); - for(String k : statsQueue.keySet()) { + // We want a copy as the underlying concurrent map could be changed while iterating + // We don't want to accidentally grab updates twice + Set keys = new HashSet<>(statsQueue.keySet()); + for(String k : keys) { InferenceStats inferenceStats = statsQueue.remove(k); - if (inferenceStats != null && inferenceStats.hasStats()) { + if (inferenceStats != null) { stats.add(inferenceStats); } } @@ -157,12 +177,46 @@ public class TrainedModelStatsService { if (bulkRequest.requests().isEmpty()) { return; } + if (stopped) { + return; + } resultsPersisterService.bulkIndexWithRetry(bulkRequest, stats.stream().map(InferenceStats::getModelId).collect(Collectors.joining(",")), () -> stopped == false, (msg) -> {}); } + private static boolean verifyIndicesPrimaryShardsAreActive(ClusterState clusterState, IndexNameExpressionResolver expressionResolver) { + String[] indices = expressionResolver.concreteIndexNames(clusterState, + IndicesOptions.LENIENT_EXPAND_OPEN_HIDDEN, + MlStatsIndex.writeAlias()); + for (String index : indices) { + if (clusterState.metadata().hasIndex(index) == false) { + return false; + } + IndexRoutingTable routingTable = clusterState.getRoutingTable().index(index); + if (routingTable == null || routingTable.allPrimaryShardsActive() == false) { + return false; + } + } + return true; + } + + private void createStatsIndexIfNecessary() { + PlainActionFuture listener = new PlainActionFuture<>(); + MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener); + listener.actionGet(); + listener = new PlainActionFuture<>(); + ElasticsearchMappings.addDocMappingIfMissing( + MlStatsIndex.writeAlias(), + MlStatsIndex::mapping, + client, + clusterState, + listener); + listener.actionGet(); + logger.debug("Created stats index"); + } + static UpdateRequest buildUpdateRequest(InferenceStats stats) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { Map params = new HashMap<>(); @@ -174,6 +228,9 @@ public class TrainedModelStatsService { UpdateRequest updateRequest = new UpdateRequest(); updateRequest.upsert(builder) .index(MlStatsIndex.writeAlias()) + // Usually, there shouldn't be a conflict, but if there is, only around a single update should have happened + // out of band. If there is MANY more than that, something strange is happening and it should fail. + .retryOnConflict(3) .id(InferenceStats.docId(stats.getModelId(), stats.getNodeId())) .script(new Script(ScriptType.INLINE, "painless", STATS_UPDATE_SCRIPT, params)); return updateRequest; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 53fe0826ddb..e7da7a36184 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -24,7 +24,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING; @@ -36,7 +35,7 @@ public class LocalModel implements Model { private final String nodeId; private final Set fieldNames; private final Map defaultFieldMap; - private final AtomicReference statsAccumulator; + private final InferenceStats.Accumulator statsAccumulator; private final TrainedModelStatsService trainedModelStatsService; private volatile long persistenceQuotient = 100; private final LongAdder currentInferenceCount; @@ -53,7 +52,7 @@ public class LocalModel implements Model { this.modelId = modelId; this.nodeId = nodeId; this.fieldNames = new HashSet<>(input.getFieldNames()); - this.statsAccumulator = new AtomicReference<>(new InferenceStats.Accumulator(modelId, nodeId)); + this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId); this.trainedModelStatsService = trainedModelStatsService; this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap); this.currentInferenceCount = new LongAdder(); @@ -71,8 +70,7 @@ public class LocalModel implements Model { @Override public InferenceStats getLatestStatsAndReset() { - InferenceStats.Accumulator toPersist = statsAccumulator.getAndSet(new InferenceStats.Accumulator(modelId, nodeId)); - return toPersist.currentStats(); + return statsAccumulator.currentStatsAndReset(); } @Override @@ -89,8 +87,8 @@ public class LocalModel implements Model { } } - void persistStats() { - trainedModelStatsService.queueStats(getLatestStatsAndReset()); + void persistStats(boolean flush) { + trainedModelStatsService.queueStats(getLatestStatsAndReset(), flush); if (persistenceQuotient < 1000 && currentInferenceCount.sum() > 1000) { persistenceQuotient = 1000; } @@ -110,27 +108,27 @@ public class LocalModel implements Model { return; } try { - statsAccumulator.updateAndGet(InferenceStats.Accumulator::incInference); + statsAccumulator.incInference(); currentInferenceCount.increment(); Model.mapFieldsIfNecessary(fields, defaultFieldMap); boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0); if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) { - statsAccumulator.updateAndGet(InferenceStats.Accumulator::incMissingFields); + statsAccumulator.incMissingFields(); if (shouldPersistStats) { - persistStats(); + persistStats(false); } listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId))); return; } InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig)); if (shouldPersistStats) { - persistStats(); + persistStats(false); } listener.onResponse(inferenceResults); } catch (Exception e) { - statsAccumulator.updateAndGet(InferenceStats.Accumulator::incFailure); + statsAccumulator.incFailure(); listener.onFailure(e); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 66f3a441dc7..842e48ad1e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -275,7 +275,8 @@ public class ModelLoadingService implements ClusterStateListener { INFERENCE_MODEL_CACHE_TTL.getKey()); auditIfNecessary(notification.getKey(), msg); } - notification.getValue().persistStats(); + // If the model is no longer referenced, flush the stats to persist as soon as possible + notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java index 845cecb96a7..4016f7410be 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsStatsAction.java @@ -30,7 +30,7 @@ public class RestGetTrainedModelsStatsAction extends BaseRestHandler { public List routes() { return unmodifiableList(asList( new Route(GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stats"), - new Route(GET, MachineLearning.BASE_PATH + MachineLearning.BASE_PATH + "inference/_stats"))); + new Route(GET, MachineLearning.BASE_PATH + "inference/_stats"))); } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 9b93e69188a..488c66dc370 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -45,6 +45,7 @@ import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.argThat; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -55,7 +56,7 @@ public class LocalModelTests extends ESTestCase { public void testClassificationInfer() throws Exception { TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); - doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class)); + doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean()); String modelId = "classification_model"; List inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical"); TrainedModelDefinition definition = new TrainedModelDefinition.Builder() @@ -126,7 +127,7 @@ public class LocalModelTests extends ESTestCase { @SuppressWarnings("unchecked") public void testClassificationInferWithDifferentPredictionFieldTypes() throws Exception { TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); - doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class)); + doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean()); String modelId = "classification_model"; List inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical"); TrainedModelDefinition definition = new TrainedModelDefinition.Builder() @@ -183,7 +184,7 @@ public class LocalModelTests extends ESTestCase { public void testRegression() throws Exception { TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); - doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class)); + doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean()); List inputFields = Arrays.asList("foo", "bar", "categorical"); TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) @@ -209,7 +210,7 @@ public class LocalModelTests extends ESTestCase { public void testAllFieldsMissing() throws Exception { TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); - doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class)); + doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean()); List inputFields = Arrays.asList("foo", "bar", "categorical"); TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) @@ -238,7 +239,7 @@ public class LocalModelTests extends ESTestCase { public void testInferPersistsStatsAfterNumberOfCalls() throws Exception { TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class); - doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class)); + doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean()); String modelId = "classification_model"; List inputFields = Arrays.asList("field.foo", "field.bar", "categorical"); TrainedModelDefinition definition = new TrainedModelDefinition.Builder() @@ -273,7 +274,7 @@ public class LocalModelTests extends ESTestCase { public boolean matches(Object o) { return ((InferenceStats)o).getInferenceCount() == 99L; } - })); + }), anyBoolean()); } private static SingleValueInferenceResults getSingleValue(Model model, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 02be4d69545..629f5d90a84 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -61,6 +61,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.argThat; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.atMost; @@ -194,7 +195,7 @@ public class ModelLoadingServiceTests extends ESTestCase { public boolean matches(final Object o) { return ((InferenceStats)o).getModelId().equals(model3); } - })); + }), anyBoolean()); // Load model 3, should invalidate 1 and 2 for(int i = 0; i < 10; i++) { @@ -209,13 +210,13 @@ public class ModelLoadingServiceTests extends ESTestCase { public boolean matches(final Object o) { return ((InferenceStats)o).getModelId().equals(model1); } - })); + }), anyBoolean()); verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher() { @Override public boolean matches(final Object o) { return ((InferenceStats)o).getModelId().equals(model2); } - })); + }), anyBoolean()); // Load model 1, should invalidate 3 for(int i = 0; i < 10; i++) { @@ -229,7 +230,7 @@ public class ModelLoadingServiceTests extends ESTestCase { public boolean matches(final Object o) { return ((InferenceStats)o).getModelId().equals(model3); } - })); + }), anyBoolean()); // Load model 2 for(int i = 0; i < 10; i++) { @@ -278,7 +279,7 @@ public class ModelLoadingServiceTests extends ESTestCase { } verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any()); - verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class)); + verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } public void testGetCachedMissingModel() throws Exception { @@ -306,7 +307,7 @@ public class ModelLoadingServiceTests extends ESTestCase { } verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any()); - verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class)); + verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } public void testGetMissingModel() { @@ -352,7 +353,7 @@ public class ModelLoadingServiceTests extends ESTestCase { } verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any()); - verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class)); + verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } @SuppressWarnings("unchecked")