[ML] Fixing inference stats race condition (#55163) (#55486)

`updateAndGet` could actually call the internal method more than once on contention.
If I read the JavaDocs, it says:
```* @param updateFunction a side-effect-free function```
So, it could be getting multiple updates on contention, thus having a race condition where stats are double counted.

To fix, I am going to use a `ReadWriteLock`. The `LongAdder` objects allows fast thread safe writes in high contention environments. These can be protected by the `ReadWriteLock::readLock`.

When stats are persisted, I need to call reset on all these adders. This is NOT thread safe if additions are taking place concurrently. So, I am going to protect with `ReadWriteLock::writeLock`.

This should prevent race conditions while allowing high (ish) throughput in the highly contention paths in inference.

I did some simple throughput tests and this change is not significantly slower and is simpler to grok (IMO).

closes  https://github.com/elastic/elasticsearch/issues/54786
This commit is contained in:
Benjamin Trent 2020-04-20 16:21:18 -04:00 committed by GitHub
parent 24d41eb695
commit cabff65aec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 242 additions and 108 deletions

View File

@ -21,6 +21,8 @@ import java.io.IOException;
import java.time.Instant;
import java.util.Objects;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
public class InferenceStats implements ToXContentObject, Writeable {
@ -204,6 +206,12 @@ public class InferenceStats implements ToXContentObject, Writeable {
private final LongAdder failureCountAccumulator = new LongAdder();
private final String modelId;
private final String nodeId;
// curious reader
// you may be wondering why the lock set to the fair.
// When `currentStatsAndReset` is called, we want it guaranteed that it will eventually execute.
// If a ReadWriteLock is unfair, there are no such guarantees.
// A call for the `writelock::lock` could pause indefinitely.
private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
public Accumulator(String modelId, String nodeId) {
this.modelId = modelId;
@ -226,22 +234,52 @@ public class InferenceStats implements ToXContentObject, Writeable {
}
public Accumulator incMissingFields() {
this.missingFieldsAccumulator.increment();
return this;
readWriteLock.readLock().lock();
try {
this.missingFieldsAccumulator.increment();
return this;
} finally {
readWriteLock.readLock().unlock();
}
}
public Accumulator incInference() {
this.inferenceAccumulator.increment();
return this;
readWriteLock.readLock().lock();
try {
this.inferenceAccumulator.increment();
return this;
} finally {
readWriteLock.readLock().unlock();
}
}
public Accumulator incFailure() {
this.failureCountAccumulator.increment();
return this;
readWriteLock.readLock().lock();
try {
this.failureCountAccumulator.increment();
return this;
} finally {
readWriteLock.readLock().unlock();
}
}
public InferenceStats currentStats() {
return currentStats(Instant.now());
/**
* Thread safe.
*
* Returns the current stats and resets the values of all the counters.
* @return The current stats
*/
public InferenceStats currentStatsAndReset() {
readWriteLock.writeLock().lock();
try {
InferenceStats stats = currentStats(Instant.now());
this.missingFieldsAccumulator.reset();
this.inferenceAccumulator.reset();
this.failureCountAccumulator.reset();
return stats;
} finally {
readWriteLock.writeLock().unlock();
}
}
public InferenceStats currentStats(Instant timeStamp) {

View File

@ -22,7 +22,9 @@ import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.test.ExternalTestCluster;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.junit.After;
import org.junit.Before;
@ -46,14 +48,16 @@ public class InferenceIngestIT extends ESRestTestCase {
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
@Before
public void createBothModels() throws Exception {
Request request = new Request("PUT", "_ml/inference/test_classification");
request.setJsonEntity(CLASSIFICATION_CONFIG);
client().performRequest(request);
request = new Request("PUT", "_ml/inference/test_regression");
request.setJsonEntity(REGRESSION_CONFIG);
client().performRequest(request);
public void setup() throws Exception {
Request loggingSettings = new Request("PUT", "_cluster/settings");
loggingSettings.setJsonEntity("" +
"{" +
"\"transient\" : {\n" +
" \"logger.org.elasticsearch.xpack.ml.inference\" : \"TRACE\"\n" +
" }" +
"}");
client().performRequest(loggingSettings);
client().performRequest(new Request("GET", "/_cluster/health?wait_for_status=green&timeout=30s"));
}
@Override
@ -64,19 +68,33 @@ public class InferenceIngestIT extends ESRestTestCase {
@After
public void cleanUpData() throws Exception {
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
client().performRequest(new Request("DELETE", InferenceIndexConstants.INDEX_PATTERN));
client().performRequest(new Request("DELETE", MlStatsIndex.indexPattern()));
Request loggingSettings = new Request("PUT", "_cluster/settings");
loggingSettings.setJsonEntity("" +
"{" +
"\"transient\" : {\n" +
" \"logger.org.elasticsearch.xpack.ml.inference\" : null\n" +
" }" +
"}");
client().performRequest(loggingSettings);
ESRestTestCase.waitForPendingTasks(adminClient());
client().performRequest(new Request("DELETE", "_ml/inference/test_classification"));
client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
}
public void testPathologicalPipelineCreationAndDeletion() throws Exception {
String classificationModelId = "test_pathological_classification";
putModel(classificationModelId, CLASSIFICATION_CONFIG);
String regressionModelId = "test_pathological_regression";
putModel(regressionModelId, REGRESSION_CONFIG);
for (int i = 0; i < 10; i++) {
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
client().performRequest(putPipeline("simple_classification_pipeline",
pipelineDefinition(classificationModelId, "classification")));
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression")));
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
}
@ -94,13 +112,30 @@ public class InferenceIngestIT extends ESRestTestCase {
QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
assertBusy(() -> {
try {
Response statsResponse = client().performRequest(new Request("GET",
"_ml/inference/" + classificationModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
} catch (ResponseException ex) {
//this could just mean shard failures.
fail(ex.getMessage());
}
}, 30, TimeUnit.SECONDS);
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786")
public void testPipelineIngest() throws Exception {
String classificationModelId = "test_classification";
putModel(classificationModelId, CLASSIFICATION_CONFIG);
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
String regressionModelId = "test_regression";
putModel(regressionModelId, REGRESSION_CONFIG);
client().performRequest(putPipeline("simple_classification_pipeline",
pipelineDefinition(classificationModelId, "classification")));
client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression")));
for (int i = 0; i < 10; i++) {
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
@ -131,21 +166,30 @@ public class InferenceIngestIT extends ESRestTestCase {
assertBusy(() -> {
try {
Response statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_classification/_stats"));
Response statsResponse = client().performRequest(new Request("GET",
"_ml/inference/" + classificationModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_regression/_stats"));
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
// can get both
statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
String entityString = EntityUtils.toString(statsResponse.getEntity());
assertThat(entityString, containsString("\"inference_count\":15"));
assertThat(entityString, containsString("\"inference_count\":10"));
} catch (ResponseException ex) {
//this could just mean shard failures.
fail(ex.getMessage());
}
}, 30, TimeUnit.SECONDS);
}
public void testSimulate() throws IOException {
String classificationModelId = "test_classification_simulate";
putModel(classificationModelId, CLASSIFICATION_CONFIG);
String regressionModelId = "test_regression_simulate";
putModel(regressionModelId, REGRESSION_CONFIG);
String source = "{\n" +
" \"pipeline\": {\n" +
" \"processors\": [\n" +
@ -157,7 +201,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"top_classes_results_field\": \"result_class_prob\"," +
" \"num_top_feature_importance_values\": 2" +
" }},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"model_id\": \"" + classificationModelId + "\",\n" +
" \"field_map\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
@ -169,7 +213,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"ml.regression\",\n" +
" \"model_id\": \"test_regression\",\n" +
" \"model_id\": \"" + regressionModelId + "\",\n" +
" \"inference_config\": {\"regression\":{}},\n" +
" \"field_map\": {\n" +
" \"col1\": \"col1\",\n" +
@ -232,6 +276,8 @@ public class InferenceIngestIT extends ESRestTestCase {
}
public void testSimulateWithDefaultMappedField() throws IOException {
String classificationModelId = "test_classification_default_mapped_field";
putModel(classificationModelId, CLASSIFICATION_CONFIG);
String source = "{\n" +
" \"pipeline\": {\n" +
" \"processors\": [\n" +
@ -243,7 +289,7 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"top_classes_results_field\": \"result_class_prob\"," +
" \"num_top_feature_importance_values\": 2" +
" }},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"model_id\": \"" + classificationModelId + "\",\n" +
" \"field_map\": {}\n" +
" }\n" +
" }\n"+
@ -607,36 +653,28 @@ public class InferenceIngestIT extends ESRestTestCase {
" \"definition\": " + CLASSIFICATION_DEFINITION +
"}";
private static final String CLASSIFICATION_PIPELINE = "{" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"model_id\": \"test_classification\",\n" +
" \"tag\": \"classification\",\n" +
" \"inference_config\": {\"classification\": {}},\n" +
" \"field_map\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }]}\n";
private static String pipelineDefinition(String modelId, String inferenceConfig) {
return "{" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"model_id\": \"" + modelId + "\",\n" +
" \"tag\": \""+ inferenceConfig + "\",\n" +
" \"inference_config\": {\"" + inferenceConfig + "\": {}},\n" +
" \"field_map\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }]}\n";
}
private static final String REGRESSION_PIPELINE = "{" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"model_id\": \"test_regression\",\n" +
" \"tag\": \"regression\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
" \"field_map\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }]}\n";
private void putModel(String modelId, String modelConfiguration) throws IOException {
Request request = new Request("PUT", "_ml/inference/" + modelId);
request.setJsonEntity(modelConfiguration);
client().performRequest(request);
}
}

View File

@ -9,18 +9,21 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.routing.IndexRoutingTable;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.indices.InvalidAliasNameException;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.threadpool.Scheduler;
@ -28,6 +31,7 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
@ -36,9 +40,11 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
@ -68,7 +74,6 @@ public class TrainedModelStatsService {
private final IndexNameExpressionResolver indexNameExpressionResolver;
private final ThreadPool threadPool;
private volatile Scheduler.Cancellable scheduledFuture;
private volatile boolean verifiedStatsIndexCreated;
private volatile boolean stopped;
private volatile ClusterState clusterState;
@ -97,15 +102,26 @@ public class TrainedModelStatsService {
clusterService.addListener((event) -> this.clusterState = event.state());
}
public void queueStats(InferenceStats stats) {
statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()),
(k, previousStats) -> previousStats == null ?
stats :
InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp()));
/**
* Queues the stats for storing.
* @param stats The stats to store or increment
* @param flush When `true`, this indicates that stats should be written as soon as possible.
* If `false`, stats are not persisted until the next periodic persistence action.
*/
public void queueStats(InferenceStats stats, boolean flush) {
if (stats.hasStats()) {
statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()),
(k, previousStats) -> previousStats == null ?
stats :
InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp()));
}
if (flush) {
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(this::updateStats);
}
}
void stop() {
logger.info("About to stop TrainedModelStatsService");
logger.debug("About to stop TrainedModelStatsService");
stopped = true;
statsQueue.clear();
@ -116,7 +132,7 @@ public class TrainedModelStatsService {
}
void start() {
logger.info("About to start TrainedModelStatsService");
logger.debug("About to start TrainedModelStatsService");
stopped = false;
scheduledFuture = threadPool.scheduleWithFixedDelay(this::updateStats,
PERSISTENCE_INTERVAL,
@ -124,27 +140,31 @@ public class TrainedModelStatsService {
}
void updateStats() {
if (clusterState == null || statsQueue.isEmpty()) {
if (clusterState == null || statsQueue.isEmpty() || stopped) {
return;
}
if (verifiedStatsIndexCreated == false) {
logger.info("About to create the stats index as it does not exist yet");
if (verifyIndicesPrimaryShardsAreActive(clusterState, indexNameExpressionResolver) == false) {
try {
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener);
listener.actionGet();
verifiedStatsIndexCreated = true;
logger.info("Created stats index");
} catch (Exception e) {
logger.error("failure creating ml stats index for storing model stats", e);
return;
logger.debug("About to create the stats index as it does not exist yet");
createStatsIndexIfNecessary();
} catch(Exception e){
// This exception occurs if, for some reason, the `createStatsIndexAndAliasIfNecessary` fails due to
// a concrete index of the alias name already existing. This error is recoverable eventually, but
// should NOT cause us to lose statistics.
if ((e instanceof InvalidAliasNameException) == false) {
logger.error("failure creating ml stats index for storing model stats", e);
return;
}
}
}
List<InferenceStats> stats = new ArrayList<>(statsQueue.size());
for(String k : statsQueue.keySet()) {
// We want a copy as the underlying concurrent map could be changed while iterating
// We don't want to accidentally grab updates twice
Set<String> keys = new HashSet<>(statsQueue.keySet());
for(String k : keys) {
InferenceStats inferenceStats = statsQueue.remove(k);
if (inferenceStats != null && inferenceStats.hasStats()) {
if (inferenceStats != null) {
stats.add(inferenceStats);
}
}
@ -157,12 +177,46 @@ public class TrainedModelStatsService {
if (bulkRequest.requests().isEmpty()) {
return;
}
if (stopped) {
return;
}
resultsPersisterService.bulkIndexWithRetry(bulkRequest,
stats.stream().map(InferenceStats::getModelId).collect(Collectors.joining(",")),
() -> stopped == false,
(msg) -> {});
}
private static boolean verifyIndicesPrimaryShardsAreActive(ClusterState clusterState, IndexNameExpressionResolver expressionResolver) {
String[] indices = expressionResolver.concreteIndexNames(clusterState,
IndicesOptions.LENIENT_EXPAND_OPEN_HIDDEN,
MlStatsIndex.writeAlias());
for (String index : indices) {
if (clusterState.metadata().hasIndex(index) == false) {
return false;
}
IndexRoutingTable routingTable = clusterState.getRoutingTable().index(index);
if (routingTable == null || routingTable.allPrimaryShardsActive() == false) {
return false;
}
}
return true;
}
private void createStatsIndexIfNecessary() {
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener);
listener.actionGet();
listener = new PlainActionFuture<>();
ElasticsearchMappings.addDocMappingIfMissing(
MlStatsIndex.writeAlias(),
MlStatsIndex::mapping,
client,
clusterState,
listener);
listener.actionGet();
logger.debug("Created stats index");
}
static UpdateRequest buildUpdateRequest(InferenceStats stats) {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
Map<String, Object> params = new HashMap<>();
@ -174,6 +228,9 @@ public class TrainedModelStatsService {
UpdateRequest updateRequest = new UpdateRequest();
updateRequest.upsert(builder)
.index(MlStatsIndex.writeAlias())
// Usually, there shouldn't be a conflict, but if there is, only around a single update should have happened
// out of band. If there is MANY more than that, something strange is happening and it should fail.
.retryOnConflict(3)
.id(InferenceStats.docId(stats.getModelId(), stats.getNodeId()))
.script(new Script(ScriptType.INLINE, "painless", STATS_UPDATE_SCRIPT, params));
return updateRequest;

View File

@ -24,7 +24,6 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
@ -36,7 +35,7 @@ public class LocalModel implements Model {
private final String nodeId;
private final Set<String> fieldNames;
private final Map<String, String> defaultFieldMap;
private final AtomicReference<InferenceStats.Accumulator> statsAccumulator;
private final InferenceStats.Accumulator statsAccumulator;
private final TrainedModelStatsService trainedModelStatsService;
private volatile long persistenceQuotient = 100;
private final LongAdder currentInferenceCount;
@ -53,7 +52,7 @@ public class LocalModel implements Model {
this.modelId = modelId;
this.nodeId = nodeId;
this.fieldNames = new HashSet<>(input.getFieldNames());
this.statsAccumulator = new AtomicReference<>(new InferenceStats.Accumulator(modelId, nodeId));
this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
this.trainedModelStatsService = trainedModelStatsService;
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
this.currentInferenceCount = new LongAdder();
@ -71,8 +70,7 @@ public class LocalModel implements Model {
@Override
public InferenceStats getLatestStatsAndReset() {
InferenceStats.Accumulator toPersist = statsAccumulator.getAndSet(new InferenceStats.Accumulator(modelId, nodeId));
return toPersist.currentStats();
return statsAccumulator.currentStatsAndReset();
}
@Override
@ -89,8 +87,8 @@ public class LocalModel implements Model {
}
}
void persistStats() {
trainedModelStatsService.queueStats(getLatestStatsAndReset());
void persistStats(boolean flush) {
trainedModelStatsService.queueStats(getLatestStatsAndReset(), flush);
if (persistenceQuotient < 1000 && currentInferenceCount.sum() > 1000) {
persistenceQuotient = 1000;
}
@ -110,27 +108,27 @@ public class LocalModel implements Model {
return;
}
try {
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incInference);
statsAccumulator.incInference();
currentInferenceCount.increment();
Model.mapFieldsIfNecessary(fields, defaultFieldMap);
boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incMissingFields);
statsAccumulator.incMissingFields();
if (shouldPersistStats) {
persistStats();
persistStats(false);
}
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
return;
}
InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig));
if (shouldPersistStats) {
persistStats();
persistStats(false);
}
listener.onResponse(inferenceResults);
} catch (Exception e) {
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incFailure);
statsAccumulator.incFailure();
listener.onFailure(e);
}
}

View File

@ -275,7 +275,8 @@ public class ModelLoadingService implements ClusterStateListener {
INFERENCE_MODEL_CACHE_TTL.getKey());
auditIfNecessary(notification.getKey(), msg);
}
notification.getValue().persistStats();
// If the model is no longer referenced, flush the stats to persist as soon as possible
notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false);
}
@Override

View File

@ -30,7 +30,7 @@ public class RestGetTrainedModelsStatsAction extends BaseRestHandler {
public List<Route> routes() {
return unmodifiableList(asList(
new Route(GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stats"),
new Route(GET, MachineLearning.BASE_PATH + MachineLearning.BASE_PATH + "inference/_stats")));
new Route(GET, MachineLearning.BASE_PATH + "inference/_stats")));
}
@Override

View File

@ -45,6 +45,7 @@ import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.argThat;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
@ -55,7 +56,7 @@ public class LocalModelTests extends ESTestCase {
public void testClassificationInfer() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
@ -126,7 +127,7 @@ public class LocalModelTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testClassificationInferWithDifferentPredictionFieldTypes() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
@ -183,7 +184,7 @@ public class LocalModelTests extends ESTestCase {
public void testRegression() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
@ -209,7 +210,7 @@ public class LocalModelTests extends ESTestCase {
public void testAllFieldsMissing() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
@ -238,7 +239,7 @@ public class LocalModelTests extends ESTestCase {
public void testInferPersistsStatsAfterNumberOfCalls() throws Exception {
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
String modelId = "classification_model";
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
@ -273,7 +274,7 @@ public class LocalModelTests extends ESTestCase {
public boolean matches(Object o) {
return ((InferenceStats)o).getInferenceCount() == 99L;
}
}));
}), anyBoolean());
}
private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model model,

View File

@ -61,6 +61,7 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atMost;
@ -194,7 +195,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
}), anyBoolean());
// Load model 3, should invalidate 1 and 2
for(int i = 0; i < 10; i++) {
@ -209,13 +210,13 @@ public class ModelLoadingServiceTests extends ESTestCase {
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model1);
}
}));
}), anyBoolean());
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
@Override
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model2);
}
}));
}), anyBoolean());
// Load model 1, should invalidate 3
for(int i = 0; i < 10; i++) {
@ -229,7 +230,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
public boolean matches(final Object o) {
return ((InferenceStats)o).getModelId().equals(model3);
}
}));
}), anyBoolean());
// Load model 2
for(int i = 0; i < 10; i++) {
@ -278,7 +279,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
}
verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
public void testGetCachedMissingModel() throws Exception {
@ -306,7 +307,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
}
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
public void testGetMissingModel() {
@ -352,7 +353,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
}
verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any());
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
}
@SuppressWarnings("unchecked")