mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-17 02:14:54 +00:00
`updateAndGet` can actually call the supplied update function more than once under contention. The JavaDocs say: ```* @param updateFunction a side-effect-free function``` So the accumulator could receive multiple updates under contention, which is a race condition in which stats are double counted. To fix this, I am going to use a `ReadWriteLock`. The `LongAdder` objects allow fast, thread-safe writes in high-contention environments. These writes can be protected by `ReadWriteLock::readLock`. When stats are persisted, I need to call reset on all of these adders. This is NOT thread safe if additions are taking place concurrently, so I am going to protect it with `ReadWriteLock::writeLock`. This should prevent race conditions while still allowing high(ish) throughput in the high-contention paths in inference. I did some simple throughput tests; this change is not significantly slower and is simpler to grok (IMO). closes https://github.com/elastic/elasticsearch/issues/54786
This commit is contained in:
parent
24d41eb695
commit
cabff65aec
@ -21,6 +21,8 @@ import java.io.IOException;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
import java.util.concurrent.locks.ReadWriteLock;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
|
||||
public class InferenceStats implements ToXContentObject, Writeable {
|
||||
|
||||
@ -204,6 +206,12 @@ public class InferenceStats implements ToXContentObject, Writeable {
|
||||
private final LongAdder failureCountAccumulator = new LongAdder();
|
||||
private final String modelId;
|
||||
private final String nodeId;
|
||||
// curious reader
|
||||
// you may be wondering why the lock is set to be fair.
|
||||
// When `currentStatsAndReset` is called, we want it guaranteed that it will eventually execute.
|
||||
// If a ReadWriteLock is unfair, there are no such guarantees.
|
||||
// A call to `writeLock::lock` could block indefinitely.
|
||||
private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
|
||||
|
||||
public Accumulator(String modelId, String nodeId) {
|
||||
this.modelId = modelId;
|
||||
@ -226,22 +234,52 @@ public class InferenceStats implements ToXContentObject, Writeable {
|
||||
}
|
||||
|
||||
public Accumulator incMissingFields() {
|
||||
this.missingFieldsAccumulator.increment();
|
||||
return this;
|
||||
readWriteLock.readLock().lock();
|
||||
try {
|
||||
this.missingFieldsAccumulator.increment();
|
||||
return this;
|
||||
} finally {
|
||||
readWriteLock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public Accumulator incInference() {
|
||||
this.inferenceAccumulator.increment();
|
||||
return this;
|
||||
readWriteLock.readLock().lock();
|
||||
try {
|
||||
this.inferenceAccumulator.increment();
|
||||
return this;
|
||||
} finally {
|
||||
readWriteLock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public Accumulator incFailure() {
|
||||
this.failureCountAccumulator.increment();
|
||||
return this;
|
||||
readWriteLock.readLock().lock();
|
||||
try {
|
||||
this.failureCountAccumulator.increment();
|
||||
return this;
|
||||
} finally {
|
||||
readWriteLock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public InferenceStats currentStats() {
|
||||
return currentStats(Instant.now());
|
||||
/**
|
||||
* Thread safe.
|
||||
*
|
||||
* Returns the current stats and resets the values of all the counters.
|
||||
* @return The current stats
|
||||
*/
|
||||
public InferenceStats currentStatsAndReset() {
|
||||
readWriteLock.writeLock().lock();
|
||||
try {
|
||||
InferenceStats stats = currentStats(Instant.now());
|
||||
this.missingFieldsAccumulator.reset();
|
||||
this.inferenceAccumulator.reset();
|
||||
this.failureCountAccumulator.reset();
|
||||
return stats;
|
||||
} finally {
|
||||
readWriteLock.writeLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public InferenceStats currentStats(Instant timeStamp) {
|
||||
|
@ -22,7 +22,9 @@ import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.test.ExternalTestCluster;
|
||||
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
|
||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
@ -46,14 +48,16 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||
|
||||
@Before
|
||||
public void createBothModels() throws Exception {
|
||||
Request request = new Request("PUT", "_ml/inference/test_classification");
|
||||
request.setJsonEntity(CLASSIFICATION_CONFIG);
|
||||
client().performRequest(request);
|
||||
|
||||
request = new Request("PUT", "_ml/inference/test_regression");
|
||||
request.setJsonEntity(REGRESSION_CONFIG);
|
||||
client().performRequest(request);
|
||||
public void setup() throws Exception {
|
||||
Request loggingSettings = new Request("PUT", "_cluster/settings");
|
||||
loggingSettings.setJsonEntity("" +
|
||||
"{" +
|
||||
"\"transient\" : {\n" +
|
||||
" \"logger.org.elasticsearch.xpack.ml.inference\" : \"TRACE\"\n" +
|
||||
" }" +
|
||||
"}");
|
||||
client().performRequest(loggingSettings);
|
||||
client().performRequest(new Request("GET", "/_cluster/health?wait_for_status=green&timeout=30s"));
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -64,19 +68,33 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
@After
|
||||
public void cleanUpData() throws Exception {
|
||||
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
|
||||
client().performRequest(new Request("DELETE", InferenceIndexConstants.INDEX_PATTERN));
|
||||
client().performRequest(new Request("DELETE", MlStatsIndex.indexPattern()));
|
||||
Request loggingSettings = new Request("PUT", "_cluster/settings");
|
||||
loggingSettings.setJsonEntity("" +
|
||||
"{" +
|
||||
"\"transient\" : {\n" +
|
||||
" \"logger.org.elasticsearch.xpack.ml.inference\" : null\n" +
|
||||
" }" +
|
||||
"}");
|
||||
client().performRequest(loggingSettings);
|
||||
ESRestTestCase.waitForPendingTasks(adminClient());
|
||||
client().performRequest(new Request("DELETE", "_ml/inference/test_classification"));
|
||||
client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
|
||||
}
|
||||
|
||||
public void testPathologicalPipelineCreationAndDeletion() throws Exception {
|
||||
String classificationModelId = "test_pathological_classification";
|
||||
putModel(classificationModelId, CLASSIFICATION_CONFIG);
|
||||
|
||||
String regressionModelId = "test_pathological_regression";
|
||||
putModel(regressionModelId, REGRESSION_CONFIG);
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
|
||||
client().performRequest(putPipeline("simple_classification_pipeline",
|
||||
pipelineDefinition(classificationModelId, "classification")));
|
||||
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
|
||||
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
|
||||
|
||||
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
|
||||
client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression")));
|
||||
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
|
||||
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
|
||||
}
|
||||
@ -94,13 +112,30 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
|
||||
|
||||
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
|
||||
assertBusy(() -> {
|
||||
try {
|
||||
Response statsResponse = client().performRequest(new Request("GET",
|
||||
"_ml/inference/" + classificationModelId + "/_stats"));
|
||||
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
|
||||
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
|
||||
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
|
||||
} catch (ResponseException ex) {
|
||||
//this could just mean shard failures.
|
||||
fail(ex.getMessage());
|
||||
}
|
||||
}, 30, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786")
|
||||
public void testPipelineIngest() throws Exception {
|
||||
String classificationModelId = "test_classification";
|
||||
putModel(classificationModelId, CLASSIFICATION_CONFIG);
|
||||
|
||||
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
|
||||
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
|
||||
String regressionModelId = "test_regression";
|
||||
putModel(regressionModelId, REGRESSION_CONFIG);
|
||||
|
||||
client().performRequest(putPipeline("simple_classification_pipeline",
|
||||
pipelineDefinition(classificationModelId, "classification")));
|
||||
client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression")));
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
|
||||
@ -131,21 +166,30 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
|
||||
assertBusy(() -> {
|
||||
try {
|
||||
Response statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_classification/_stats"));
|
||||
Response statsResponse = client().performRequest(new Request("GET",
|
||||
"_ml/inference/" + classificationModelId + "/_stats"));
|
||||
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
|
||||
statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_regression/_stats"));
|
||||
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
|
||||
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
|
||||
// can get both
|
||||
statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
|
||||
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
|
||||
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
|
||||
String entityString = EntityUtils.toString(statsResponse.getEntity());
|
||||
assertThat(entityString, containsString("\"inference_count\":15"));
|
||||
assertThat(entityString, containsString("\"inference_count\":10"));
|
||||
} catch (ResponseException ex) {
|
||||
//this could just mean shard failures.
|
||||
fail(ex.getMessage());
|
||||
}
|
||||
}, 30, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
public void testSimulate() throws IOException {
|
||||
String classificationModelId = "test_classification_simulate";
|
||||
putModel(classificationModelId, CLASSIFICATION_CONFIG);
|
||||
|
||||
String regressionModelId = "test_regression_simulate";
|
||||
putModel(regressionModelId, REGRESSION_CONFIG);
|
||||
|
||||
String source = "{\n" +
|
||||
" \"pipeline\": {\n" +
|
||||
" \"processors\": [\n" +
|
||||
@ -157,7 +201,7 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
" \"top_classes_results_field\": \"result_class_prob\"," +
|
||||
" \"num_top_feature_importance_values\": 2" +
|
||||
" }},\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"model_id\": \"" + classificationModelId + "\",\n" +
|
||||
" \"field_map\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
@ -169,7 +213,7 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"target_field\": \"ml.regression\",\n" +
|
||||
" \"model_id\": \"test_regression\",\n" +
|
||||
" \"model_id\": \"" + regressionModelId + "\",\n" +
|
||||
" \"inference_config\": {\"regression\":{}},\n" +
|
||||
" \"field_map\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
@ -232,6 +276,8 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
}
|
||||
|
||||
public void testSimulateWithDefaultMappedField() throws IOException {
|
||||
String classificationModelId = "test_classification_default_mapped_field";
|
||||
putModel(classificationModelId, CLASSIFICATION_CONFIG);
|
||||
String source = "{\n" +
|
||||
" \"pipeline\": {\n" +
|
||||
" \"processors\": [\n" +
|
||||
@ -243,7 +289,7 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
" \"top_classes_results_field\": \"result_class_prob\"," +
|
||||
" \"num_top_feature_importance_values\": 2" +
|
||||
" }},\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"model_id\": \"" + classificationModelId + "\",\n" +
|
||||
" \"field_map\": {}\n" +
|
||||
" }\n" +
|
||||
" }\n"+
|
||||
@ -607,36 +653,28 @@ public class InferenceIngestIT extends ESRestTestCase {
|
||||
" \"definition\": " + CLASSIFICATION_DEFINITION +
|
||||
"}";
|
||||
|
||||
private static final String CLASSIFICATION_PIPELINE = "{" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"model_id\": \"test_classification\",\n" +
|
||||
" \"tag\": \"classification\",\n" +
|
||||
" \"inference_config\": {\"classification\": {}},\n" +
|
||||
" \"field_map\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
private static String pipelineDefinition(String modelId, String inferenceConfig) {
|
||||
return "{" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"model_id\": \"" + modelId + "\",\n" +
|
||||
" \"tag\": \""+ inferenceConfig + "\",\n" +
|
||||
" \"inference_config\": {\"" + inferenceConfig + "\": {}},\n" +
|
||||
" \"field_map\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
}
|
||||
|
||||
private static final String REGRESSION_PIPELINE = "{" +
|
||||
" \"processors\": [\n" +
|
||||
" {\n" +
|
||||
" \"inference\": {\n" +
|
||||
" \"model_id\": \"test_regression\",\n" +
|
||||
" \"tag\": \"regression\",\n" +
|
||||
" \"inference_config\": {\"regression\": {}},\n" +
|
||||
" \"field_map\": {\n" +
|
||||
" \"col1\": \"col1\",\n" +
|
||||
" \"col2\": \"col2\",\n" +
|
||||
" \"col3\": \"col3\",\n" +
|
||||
" \"col4\": \"col4\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
" }]}\n";
|
||||
private void putModel(String modelId, String modelConfiguration) throws IOException {
|
||||
Request request = new Request("PUT", "_ml/inference/" + modelId);
|
||||
request.setJsonEntity(modelConfiguration);
|
||||
client().performRequest(request);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -9,18 +9,21 @@ import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.action.bulk.BulkRequest;
|
||||
import org.elasticsearch.action.support.IndicesOptions;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.action.update.UpdateRequest;
|
||||
import org.elasticsearch.client.OriginSettingClient;
|
||||
import org.elasticsearch.cluster.ClusterState;
|
||||
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
||||
import org.elasticsearch.cluster.routing.IndexRoutingTable;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.component.LifecycleListener;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.ToXContent;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.indices.InvalidAliasNameException;
|
||||
import org.elasticsearch.script.Script;
|
||||
import org.elasticsearch.script.ScriptType;
|
||||
import org.elasticsearch.threadpool.Scheduler;
|
||||
@ -28,6 +31,7 @@ import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
|
||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
|
||||
@ -36,9 +40,11 @@ import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -68,7 +74,6 @@ public class TrainedModelStatsService {
|
||||
private final IndexNameExpressionResolver indexNameExpressionResolver;
|
||||
private final ThreadPool threadPool;
|
||||
private volatile Scheduler.Cancellable scheduledFuture;
|
||||
private volatile boolean verifiedStatsIndexCreated;
|
||||
private volatile boolean stopped;
|
||||
private volatile ClusterState clusterState;
|
||||
|
||||
@ -97,15 +102,26 @@ public class TrainedModelStatsService {
|
||||
clusterService.addListener((event) -> this.clusterState = event.state());
|
||||
}
|
||||
|
||||
public void queueStats(InferenceStats stats) {
|
||||
statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()),
|
||||
(k, previousStats) -> previousStats == null ?
|
||||
stats :
|
||||
InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp()));
|
||||
/**
|
||||
* Queues the stats for storing.
|
||||
* @param stats The stats to store or increment
|
||||
* @param flush When `true`, this indicates that stats should be written as soon as possible.
|
||||
* If `false`, stats are not persisted until the next periodic persistence action.
|
||||
*/
|
||||
public void queueStats(InferenceStats stats, boolean flush) {
|
||||
if (stats.hasStats()) {
|
||||
statsQueue.compute(InferenceStats.docId(stats.getModelId(), stats.getNodeId()),
|
||||
(k, previousStats) -> previousStats == null ?
|
||||
stats :
|
||||
InferenceStats.accumulator(stats).merge(previousStats).currentStats(stats.getTimeStamp()));
|
||||
}
|
||||
if (flush) {
|
||||
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(this::updateStats);
|
||||
}
|
||||
}
|
||||
|
||||
void stop() {
|
||||
logger.info("About to stop TrainedModelStatsService");
|
||||
logger.debug("About to stop TrainedModelStatsService");
|
||||
stopped = true;
|
||||
statsQueue.clear();
|
||||
|
||||
@ -116,7 +132,7 @@ public class TrainedModelStatsService {
|
||||
}
|
||||
|
||||
void start() {
|
||||
logger.info("About to start TrainedModelStatsService");
|
||||
logger.debug("About to start TrainedModelStatsService");
|
||||
stopped = false;
|
||||
scheduledFuture = threadPool.scheduleWithFixedDelay(this::updateStats,
|
||||
PERSISTENCE_INTERVAL,
|
||||
@ -124,27 +140,31 @@ public class TrainedModelStatsService {
|
||||
}
|
||||
|
||||
void updateStats() {
|
||||
if (clusterState == null || statsQueue.isEmpty()) {
|
||||
if (clusterState == null || statsQueue.isEmpty() || stopped) {
|
||||
return;
|
||||
}
|
||||
if (verifiedStatsIndexCreated == false) {
|
||||
logger.info("About to create the stats index as it does not exist yet");
|
||||
if (verifyIndicesPrimaryShardsAreActive(clusterState, indexNameExpressionResolver) == false) {
|
||||
try {
|
||||
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
|
||||
MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener);
|
||||
listener.actionGet();
|
||||
verifiedStatsIndexCreated = true;
|
||||
logger.info("Created stats index");
|
||||
} catch (Exception e) {
|
||||
logger.error("failure creating ml stats index for storing model stats", e);
|
||||
return;
|
||||
logger.debug("About to create the stats index as it does not exist yet");
|
||||
createStatsIndexIfNecessary();
|
||||
} catch(Exception e){
|
||||
// This exception occurs if, for some reason, the `createStatsIndexAndAliasIfNecessary` fails due to
|
||||
// a concrete index of the alias name already existing. This error is recoverable eventually, but
|
||||
// should NOT cause us to lose statistics.
|
||||
if ((e instanceof InvalidAliasNameException) == false) {
|
||||
logger.error("failure creating ml stats index for storing model stats", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
List<InferenceStats> stats = new ArrayList<>(statsQueue.size());
|
||||
for(String k : statsQueue.keySet()) {
|
||||
// We want a copy as the underlying concurrent map could be changed while iterating
|
||||
// We don't want to accidentally grab updates twice
|
||||
Set<String> keys = new HashSet<>(statsQueue.keySet());
|
||||
for(String k : keys) {
|
||||
InferenceStats inferenceStats = statsQueue.remove(k);
|
||||
if (inferenceStats != null && inferenceStats.hasStats()) {
|
||||
if (inferenceStats != null) {
|
||||
stats.add(inferenceStats);
|
||||
}
|
||||
}
|
||||
@ -157,12 +177,46 @@ public class TrainedModelStatsService {
|
||||
if (bulkRequest.requests().isEmpty()) {
|
||||
return;
|
||||
}
|
||||
if (stopped) {
|
||||
return;
|
||||
}
|
||||
resultsPersisterService.bulkIndexWithRetry(bulkRequest,
|
||||
stats.stream().map(InferenceStats::getModelId).collect(Collectors.joining(",")),
|
||||
() -> stopped == false,
|
||||
(msg) -> {});
|
||||
}
|
||||
|
||||
private static boolean verifyIndicesPrimaryShardsAreActive(ClusterState clusterState, IndexNameExpressionResolver expressionResolver) {
|
||||
String[] indices = expressionResolver.concreteIndexNames(clusterState,
|
||||
IndicesOptions.LENIENT_EXPAND_OPEN_HIDDEN,
|
||||
MlStatsIndex.writeAlias());
|
||||
for (String index : indices) {
|
||||
if (clusterState.metadata().hasIndex(index) == false) {
|
||||
return false;
|
||||
}
|
||||
IndexRoutingTable routingTable = clusterState.getRoutingTable().index(index);
|
||||
if (routingTable == null || routingTable.allPrimaryShardsActive() == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private void createStatsIndexIfNecessary() {
|
||||
PlainActionFuture<Boolean> listener = new PlainActionFuture<>();
|
||||
MlStatsIndex.createStatsIndexAndAliasIfNecessary(client, clusterState, indexNameExpressionResolver, listener);
|
||||
listener.actionGet();
|
||||
listener = new PlainActionFuture<>();
|
||||
ElasticsearchMappings.addDocMappingIfMissing(
|
||||
MlStatsIndex.writeAlias(),
|
||||
MlStatsIndex::mapping,
|
||||
client,
|
||||
clusterState,
|
||||
listener);
|
||||
listener.actionGet();
|
||||
logger.debug("Created stats index");
|
||||
}
|
||||
|
||||
static UpdateRequest buildUpdateRequest(InferenceStats stats) {
|
||||
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
@ -174,6 +228,9 @@ public class TrainedModelStatsService {
|
||||
UpdateRequest updateRequest = new UpdateRequest();
|
||||
updateRequest.upsert(builder)
|
||||
.index(MlStatsIndex.writeAlias())
|
||||
// Usually, there shouldn't be a conflict, but if there is, only around a single update should have happened
|
||||
// out of band. If there are MANY more than that, something strange is happening and it should fail.
|
||||
.retryOnConflict(3)
|
||||
.id(InferenceStats.docId(stats.getModelId(), stats.getNodeId()))
|
||||
.script(new Script(ScriptType.INLINE, "painless", STATS_UPDATE_SCRIPT, params));
|
||||
return updateRequest;
|
||||
|
@ -24,7 +24,6 @@ import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
|
||||
@ -36,7 +35,7 @@ public class LocalModel implements Model {
|
||||
private final String nodeId;
|
||||
private final Set<String> fieldNames;
|
||||
private final Map<String, String> defaultFieldMap;
|
||||
private final AtomicReference<InferenceStats.Accumulator> statsAccumulator;
|
||||
private final InferenceStats.Accumulator statsAccumulator;
|
||||
private final TrainedModelStatsService trainedModelStatsService;
|
||||
private volatile long persistenceQuotient = 100;
|
||||
private final LongAdder currentInferenceCount;
|
||||
@ -53,7 +52,7 @@ public class LocalModel implements Model {
|
||||
this.modelId = modelId;
|
||||
this.nodeId = nodeId;
|
||||
this.fieldNames = new HashSet<>(input.getFieldNames());
|
||||
this.statsAccumulator = new AtomicReference<>(new InferenceStats.Accumulator(modelId, nodeId));
|
||||
this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId);
|
||||
this.trainedModelStatsService = trainedModelStatsService;
|
||||
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
|
||||
this.currentInferenceCount = new LongAdder();
|
||||
@ -71,8 +70,7 @@ public class LocalModel implements Model {
|
||||
|
||||
@Override
|
||||
public InferenceStats getLatestStatsAndReset() {
|
||||
InferenceStats.Accumulator toPersist = statsAccumulator.getAndSet(new InferenceStats.Accumulator(modelId, nodeId));
|
||||
return toPersist.currentStats();
|
||||
return statsAccumulator.currentStatsAndReset();
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -89,8 +87,8 @@ public class LocalModel implements Model {
|
||||
}
|
||||
}
|
||||
|
||||
void persistStats() {
|
||||
trainedModelStatsService.queueStats(getLatestStatsAndReset());
|
||||
void persistStats(boolean flush) {
|
||||
trainedModelStatsService.queueStats(getLatestStatsAndReset(), flush);
|
||||
if (persistenceQuotient < 1000 && currentInferenceCount.sum() > 1000) {
|
||||
persistenceQuotient = 1000;
|
||||
}
|
||||
@ -110,27 +108,27 @@ public class LocalModel implements Model {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incInference);
|
||||
statsAccumulator.incInference();
|
||||
currentInferenceCount.increment();
|
||||
|
||||
Model.mapFieldsIfNecessary(fields, defaultFieldMap);
|
||||
|
||||
boolean shouldPersistStats = ((currentInferenceCount.sum() + 1) % persistenceQuotient == 0);
|
||||
if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
|
||||
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incMissingFields);
|
||||
statsAccumulator.incMissingFields();
|
||||
if (shouldPersistStats) {
|
||||
persistStats();
|
||||
persistStats(false);
|
||||
}
|
||||
listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
|
||||
return;
|
||||
}
|
||||
InferenceResults inferenceResults = trainedModelDefinition.infer(fields, update.apply(inferenceConfig));
|
||||
if (shouldPersistStats) {
|
||||
persistStats();
|
||||
persistStats(false);
|
||||
}
|
||||
listener.onResponse(inferenceResults);
|
||||
} catch (Exception e) {
|
||||
statsAccumulator.updateAndGet(InferenceStats.Accumulator::incFailure);
|
||||
statsAccumulator.incFailure();
|
||||
listener.onFailure(e);
|
||||
}
|
||||
}
|
||||
|
@ -275,7 +275,8 @@ public class ModelLoadingService implements ClusterStateListener {
|
||||
INFERENCE_MODEL_CACHE_TTL.getKey());
|
||||
auditIfNecessary(notification.getKey(), msg);
|
||||
}
|
||||
notification.getValue().persistStats();
|
||||
// If the model is no longer referenced, flush the stats to persist as soon as possible
|
||||
notification.getValue().persistStats(referencedModels.contains(notification.getKey()) == false);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -30,7 +30,7 @@ public class RestGetTrainedModelsStatsAction extends BaseRestHandler {
|
||||
public List<Route> routes() {
|
||||
return unmodifiableList(asList(
|
||||
new Route(GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stats"),
|
||||
new Route(GET, MachineLearning.BASE_PATH + MachineLearning.BASE_PATH + "inference/_stats")));
|
||||
new Route(GET, MachineLearning.BASE_PATH + "inference/_stats")));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -45,6 +45,7 @@ import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyBoolean;
|
||||
import static org.mockito.Matchers.argThat;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
@ -55,7 +56,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
|
||||
public void testClassificationInfer() throws Exception {
|
||||
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
String modelId = "classification_model";
|
||||
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
|
||||
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
|
||||
@ -126,7 +127,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testClassificationInferWithDifferentPredictionFieldTypes() throws Exception {
|
||||
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
String modelId = "classification_model";
|
||||
List<String> inputFields = Arrays.asList("field.foo.keyword", "field.bar", "categorical");
|
||||
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
|
||||
@ -183,7 +184,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
|
||||
public void testRegression() throws Exception {
|
||||
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
|
||||
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
@ -209,7 +210,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
|
||||
public void testAllFieldsMissing() throws Exception {
|
||||
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
|
||||
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
|
||||
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
|
||||
@ -238,7 +239,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
|
||||
public void testInferPersistsStatsAfterNumberOfCalls() throws Exception {
|
||||
TrainedModelStatsService modelStatsService = mock(TrainedModelStatsService.class);
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class));
|
||||
doAnswer((args) -> null).when(modelStatsService).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
String modelId = "classification_model";
|
||||
List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
|
||||
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
|
||||
@ -273,7 +274,7 @@ public class LocalModelTests extends ESTestCase {
|
||||
public boolean matches(Object o) {
|
||||
return ((InferenceStats)o).getInferenceCount() == 99L;
|
||||
}
|
||||
}));
|
||||
}), anyBoolean());
|
||||
}
|
||||
|
||||
private static <T extends InferenceConfig> SingleValueInferenceResults getSingleValue(Model model,
|
||||
|
@ -61,6 +61,7 @@ import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyBoolean;
|
||||
import static org.mockito.Matchers.argThat;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.atMost;
|
||||
@ -194,7 +195,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model3);
|
||||
}
|
||||
}));
|
||||
}), anyBoolean());
|
||||
|
||||
// Load model 3, should invalidate 1 and 2
|
||||
for(int i = 0; i < 10; i++) {
|
||||
@ -209,13 +210,13 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model1);
|
||||
}
|
||||
}));
|
||||
}), anyBoolean());
|
||||
verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<InferenceStats>() {
|
||||
@Override
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model2);
|
||||
}
|
||||
}));
|
||||
}), anyBoolean());
|
||||
|
||||
// Load model 1, should invalidate 3
|
||||
for(int i = 0; i < 10; i++) {
|
||||
@ -229,7 +230,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||
public boolean matches(final Object o) {
|
||||
return ((InferenceStats)o).getModelId().equals(model3);
|
||||
}
|
||||
}));
|
||||
}), anyBoolean());
|
||||
|
||||
// Load model 2
|
||||
for(int i = 0; i < 10; i++) {
|
||||
@ -278,7 +279,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||
}
|
||||
|
||||
verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any());
|
||||
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
|
||||
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
}
|
||||
|
||||
public void testGetCachedMissingModel() throws Exception {
|
||||
@ -306,7 +307,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||
}
|
||||
|
||||
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any());
|
||||
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
|
||||
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
}
|
||||
|
||||
public void testGetMissingModel() {
|
||||
@ -352,7 +353,7 @@ public class ModelLoadingServiceTests extends ESTestCase {
|
||||
}
|
||||
|
||||
verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any());
|
||||
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class));
|
||||
verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
Loading…
x
Reference in New Issue
Block a user