diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 3d3d280954b..74ab339f764 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -13,6 +13,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetaData; @@ -60,6 +61,7 @@ import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.ml.MachineLearningField; @@ -524,9 +526,11 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu DataFrameAnalyticsAuditor dataFrameAnalyticsAuditor = new DataFrameAnalyticsAuditor(client, clusterService.getNodeName()); InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService.getNodeName()); this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor); - ResultsPersisterService resultsPersisterService = new ResultsPersisterService(client, clusterService, settings); + OriginSettingClient originSettingClient = new OriginSettingClient(client, ClientHelper.ML_ORIGIN); + ResultsPersisterService resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, settings); JobResultsProvider jobResultsProvider = new JobResultsProvider(client, settings); - JobResultsPersister jobResultsPersister = new JobResultsPersister(client, resultsPersisterService, anomalyDetectionAuditor); + JobResultsPersister jobResultsPersister = + new JobResultsPersister(originSettingClient, resultsPersisterService, anomalyDetectionAuditor); JobDataCountsPersister jobDataCountsPersister = new JobDataCountsPersister(client, resultsPersisterService, anomalyDetectionAuditor); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersister.java index b9a3fbaa570..b156840aa64 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersister.java @@ -17,12 +17,17 @@ import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.IdsQueryBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedTimingStats; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; @@ -72,11 +77,11 @@ public class JobResultsPersister { private static final Logger logger = LogManager.getLogger(JobResultsPersister.class); - private final Client client; + private final OriginSettingClient client; private final ResultsPersisterService resultsPersisterService; private final AnomalyDetectionAuditor auditor; - public JobResultsPersister(Client client, + public JobResultsPersister(OriginSettingClient client, ResultsPersisterService resultsPersisterService, AnomalyDetectionAuditor auditor) { this.client = client; @@ -244,9 +249,9 @@ public class JobResultsPersister { * @param category The category to be persisted */ public void persistCategoryDefinition(CategoryDefinition category, Supplier shouldRetry) { - Persistable persistable = new Persistable(category.getJobId(), category, category.getId()); - - persistable.persist(AnomalyDetectorsIndex.resultsWriteAlias(category.getJobId()), shouldRetry); + Persistable persistable = + new Persistable(AnomalyDetectorsIndex.resultsWriteAlias(category.getJobId()), category.getJobId(), category, category.getId()); + persistable.persist(shouldRetry); // Don't commit as we expect masses of these updates and they're not // read again by this process } @@ -255,17 +260,61 @@ public class JobResultsPersister { * Persist the quantiles (blocking) */ public void persistQuantiles(Quantiles quantiles, Supplier shouldRetry) { - Persistable persistable = new Persistable(quantiles.getJobId(), quantiles, Quantiles.documentId(quantiles.getJobId())); - persistable.persist(AnomalyDetectorsIndex.jobStateIndexWriteAlias(), shouldRetry); + String jobId = quantiles.getJobId(); + String quantilesDocId = Quantiles.documentId(jobId); + SearchRequest searchRequest = buildQuantilesDocIdSearch(quantilesDocId); + SearchResponse searchResponse = + resultsPersisterService.searchWithRetry( + searchRequest, + jobId, + shouldRetry, + (msg) -> auditor.warning(jobId, quantilesDocId + " " + msg)); + String indexOrAlias = + searchResponse.getHits().getHits().length > 0 + ? searchResponse.getHits().getHits()[0].getIndex() + : AnomalyDetectorsIndex.jobStateIndexWriteAlias(); + + Persistable persistable = new Persistable(indexOrAlias, quantiles.getJobId(), quantiles, quantilesDocId); + persistable.persist(shouldRetry); } /** * Persist the quantiles (async) */ public void persistQuantiles(Quantiles quantiles, WriteRequest.RefreshPolicy refreshPolicy, ActionListener listener) { - Persistable persistable = new Persistable(quantiles.getJobId(), quantiles, Quantiles.documentId(quantiles.getJobId())); - persistable.setRefreshPolicy(refreshPolicy); - persistable.persist(AnomalyDetectorsIndex.jobStateIndexWriteAlias(), listener); + String quantilesDocId = Quantiles.documentId(quantiles.getJobId()); + + // Step 2: Create or update the quantiles document: + // - if the document did not exist, create the new one in the current write index + // - if the document did exist, update it in the index where it resides (not necessarily the current write index) + ActionListener searchFormerQuantilesDocListener = ActionListener.wrap( + searchResponse -> { + String indexOrAlias = + searchResponse.getHits().getHits().length > 0 + ? searchResponse.getHits().getHits()[0].getIndex() + : AnomalyDetectorsIndex.jobStateIndexWriteAlias(); + + Persistable persistable = new Persistable(indexOrAlias, quantiles.getJobId(), quantiles, quantilesDocId); + persistable.setRefreshPolicy(refreshPolicy); + persistable.persist(listener); + }, + listener::onFailure + ); + + // Step 1: Search for existing quantiles document in .ml-state* + SearchRequest searchRequest = buildQuantilesDocIdSearch(quantilesDocId); + executeAsyncWithOrigin( + client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest, searchFormerQuantilesDocListener, client::search); + } + + private static SearchRequest buildQuantilesDocIdSearch(String quantilesDocId) { + return new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern()) + .allowPartialSearchResults(false) + .source( + new SearchSourceBuilder() + .size(1) + .trackTotalHits(false) + .query(new BoolQueryBuilder().filter(new IdsQueryBuilder().addIds(quantilesDocId)))); } /** @@ -274,9 +323,14 @@ public class JobResultsPersister { public BulkResponse persistModelSnapshot(ModelSnapshot modelSnapshot, WriteRequest.RefreshPolicy refreshPolicy, Supplier shouldRetry) { - Persistable persistable = new Persistable(modelSnapshot.getJobId(), modelSnapshot, ModelSnapshot.documentId(modelSnapshot)); + Persistable persistable = + new Persistable( + AnomalyDetectorsIndex.resultsWriteAlias(modelSnapshot.getJobId()), + modelSnapshot.getJobId(), + modelSnapshot, + ModelSnapshot.documentId(modelSnapshot)); persistable.setRefreshPolicy(refreshPolicy); - return persistable.persist(AnomalyDetectorsIndex.resultsWriteAlias(modelSnapshot.getJobId()), shouldRetry); + return persistable.persist(shouldRetry); } /** @@ -285,8 +339,9 @@ public class JobResultsPersister { public void persistModelSizeStats(ModelSizeStats modelSizeStats, Supplier shouldRetry) { String jobId = modelSizeStats.getJobId(); logger.trace("[{}] Persisting model size stats, for size {}", jobId, modelSizeStats.getModelBytes()); - Persistable persistable = new Persistable(jobId, modelSizeStats, modelSizeStats.getId()); - persistable.persist(AnomalyDetectorsIndex.resultsWriteAlias(jobId), shouldRetry); + Persistable persistable = + new Persistable(AnomalyDetectorsIndex.resultsWriteAlias(jobId), jobId, modelSizeStats, modelSizeStats.getId()); + persistable.persist(shouldRetry); } /** @@ -296,9 +351,10 @@ public class JobResultsPersister { ActionListener listener) { String jobId = modelSizeStats.getJobId(); logger.trace("[{}] Persisting model size stats, for size {}", jobId, modelSizeStats.getModelBytes()); - Persistable persistable = new Persistable(jobId, modelSizeStats, modelSizeStats.getId()); + Persistable persistable = + new Persistable(AnomalyDetectorsIndex.resultsWriteAlias(jobId), jobId, modelSizeStats, modelSizeStats.getId()); persistable.setRefreshPolicy(refreshPolicy); - persistable.persist(AnomalyDetectorsIndex.resultsWriteAlias(jobId), listener); + persistable.persist(listener); } /** @@ -354,13 +410,15 @@ public class JobResultsPersister { public BulkResponse persistDatafeedTimingStats(DatafeedTimingStats timingStats, WriteRequest.RefreshPolicy refreshPolicy) { String jobId = timingStats.getJobId(); logger.trace("[{}] Persisting datafeed timing stats", jobId); - Persistable persistable = new Persistable( - jobId, - timingStats, - new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), - DatafeedTimingStats.documentId(timingStats.getJobId())); + Persistable persistable = + new Persistable( + AnomalyDetectorsIndex.resultsWriteAlias(jobId), + jobId, + timingStats, + new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), + DatafeedTimingStats.documentId(timingStats.getJobId())); persistable.setRefreshPolicy(refreshPolicy); - return persistable.persist(AnomalyDetectorsIndex.resultsWriteAlias(jobId), () -> true); + return persistable.persist(() -> true); } private static XContentBuilder toXContentBuilder(ToXContent obj, ToXContent.Params params) throws IOException { @@ -371,17 +429,19 @@ public class JobResultsPersister { private class Persistable { + private final String indexName; private final String jobId; private final ToXContent object; private final ToXContent.Params params; private final String id; private WriteRequest.RefreshPolicy refreshPolicy; - Persistable(String jobId, ToXContent object, String id) { - this(jobId, object, ToXContent.EMPTY_PARAMS, id); + Persistable(String indexName, String jobId, ToXContent object, String id) { + this(indexName, jobId, object, ToXContent.EMPTY_PARAMS, id); } - Persistable(String jobId, ToXContent object, ToXContent.Params params, String id) { + Persistable(String indexName, String jobId, ToXContent object, ToXContent.Params params, String id) { + this.indexName = indexName; this.jobId = jobId; this.object = object; this.params = params; @@ -393,7 +453,7 @@ public class JobResultsPersister { this.refreshPolicy = refreshPolicy; } - BulkResponse persist(String indexName, Supplier shouldRetry) { + BulkResponse persist(Supplier shouldRetry) { logCall(indexName); try { return resultsPersisterService.indexWithRetry(jobId, @@ -414,7 +474,7 @@ public class JobResultsPersister { } } - void persist(String indexName, ActionListener listener) { + void persist(ActionListener listener) { logCall(indexName); try (XContentBuilder content = toXContentBuilder(object, params)) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java index 09224d99399..a775342b880 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java @@ -13,17 +13,20 @@ import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.rest.RestStatus; import java.io.IOException; import java.time.Duration; @@ -34,9 +37,6 @@ import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; - public class ResultsPersisterService { private static final Logger LOGGER = LogManager.getLogger(ResultsPersisterService.class); @@ -52,10 +52,20 @@ public class ResultsPersisterService { // Having an exponent higher than this causes integer overflow private static final int MAX_RETRY_EXPONENT = 24; - private final Client client; + private final CheckedConsumer sleeper; + private final OriginSettingClient client; private volatile int maxFailureRetries; - public ResultsPersisterService(Client client, ClusterService clusterService, Settings settings) { + public ResultsPersisterService(OriginSettingClient client, ClusterService clusterService, Settings settings) { + this(Thread::sleep, client, clusterService, settings); + } + + // Visible for testing + ResultsPersisterService(CheckedConsumer sleeper, + OriginSettingClient client, + ClusterService clusterService, + Settings settings) { + this.sleeper = sleeper; this.client = client; this.maxFailureRetries = PERSIST_RESULTS_MAX_RETRIES.get(settings); clusterService.getClusterSettings() @@ -85,29 +95,84 @@ public class ResultsPersisterService { String jobId, Supplier shouldRetry, Consumer msgHandler) { - int currentMin = MIN_RETRY_SLEEP_MILLIS; - int currentMax = MIN_RETRY_SLEEP_MILLIS; - int currentAttempt = 0; - BulkResponse bulkResponse = null; - final Random random = Randomness.get(); - while(currentAttempt <= maxFailureRetries) { - bulkResponse = bulkIndex(bulkRequest); + RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler); + while (true) { + BulkResponse bulkResponse = client.bulk(bulkRequest).actionGet(); if (bulkResponse.hasFailures() == false) { return bulkResponse; } - if (shouldRetry.get() == false) { - throw new ElasticsearchException("[{}] failed to index all results. {}", jobId, bulkResponse.buildFailureMessage()); - } - if (currentAttempt > maxFailureRetries) { - LOGGER.warn("[{}] failed to index after [{}] attempts. Setting [xpack.ml.persist_results_max_retries] was reduced", - jobId, - currentAttempt); - throw new ElasticsearchException("[{}] failed to index all results after [{}] attempts. {}", - jobId, - currentAttempt, - bulkResponse.buildFailureMessage()); + + retryContext.nextIteration("index", bulkResponse.buildFailureMessage()); + + // We should only retry the docs that failed. + bulkRequest = buildNewRequestFromFailures(bulkRequest, bulkResponse); + } + } + + public SearchResponse searchWithRetry(SearchRequest searchRequest, + String jobId, + Supplier shouldRetry, + Consumer msgHandler) { + RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler); + while (true) { + String failureMessage; + try { + SearchResponse searchResponse = client.search(searchRequest).actionGet(); + if (RestStatus.OK.equals(searchResponse.status())) { + return searchResponse; + } + failureMessage = searchResponse.status().toString(); + } catch (ElasticsearchException e) { + LOGGER.warn("[" + jobId + "] Exception while executing search action", e); + failureMessage = e.getDetailedMessage(); } + + retryContext.nextIteration("search", failureMessage); + } + } + + /** + * {@link RetryContext} object handles logic that is executed between consecutive retries of an action. + * + * Note that it does not execute the action itself. + */ + private class RetryContext { + + final String jobId; + final Supplier shouldRetry; + final Consumer msgHandler; + final Random random = Randomness.get(); + + int currentAttempt = 0; + int currentMin = MIN_RETRY_SLEEP_MILLIS; + int currentMax = MIN_RETRY_SLEEP_MILLIS; + + RetryContext(String jobId, Supplier shouldRetry, Consumer msgHandler) { + this.jobId = jobId; + this.shouldRetry = shouldRetry; + this.msgHandler = msgHandler; + } + + void nextIteration(String actionName, String failureMessage) { currentAttempt++; + + // If the outside conditions have changed and retries are no longer needed, do not retry. + if (shouldRetry.get() == false) { + String msg = new ParameterizedMessage( + "[{}] should not retry {} after [{}] attempts. {}", jobId, actionName, currentAttempt, failureMessage) + .getFormattedMessage(); + LOGGER.info(msg); + throw new ElasticsearchException(msg); + } + + // If the configured maximum number of retries has been reached, do not retry. + if (currentAttempt > maxFailureRetries) { + String msg = new ParameterizedMessage( + "[{}] failed to {} after [{}] attempts. {}", jobId, actionName, currentAttempt, failureMessage).getFormattedMessage(); + LOGGER.warn(msg); + throw new ElasticsearchException(msg); + } + // Since we exponentially increase, we don't want force randomness to have an excessively long sleep if (currentMax < MAX_RETRY_SLEEP_MILLIS) { currentMin = currentMax; @@ -121,38 +186,26 @@ public class ResultsPersisterService { int randSleep = currentMin + random.nextInt(randBound); { String msg = new ParameterizedMessage( - "failed to index after [{}] attempts. Will attempt again in [{}].", + "failed to {} after [{}] attempts. Will attempt again in [{}].", + actionName, currentAttempt, TimeValue.timeValueMillis(randSleep).getStringRep()) .getFormattedMessage(); - LOGGER.warn(()-> new ParameterizedMessage("[{}] {}", jobId, msg)); + LOGGER.warn(() -> new ParameterizedMessage("[{}] {}", jobId, msg)); msgHandler.accept(msg); } - // We should only retry the docs that failed. - bulkRequest = buildNewRequestFromFailures(bulkRequest, bulkResponse); try { - Thread.sleep(randSleep); + sleeper.accept(randSleep); } catch (InterruptedException interruptedException) { LOGGER.warn( - new ParameterizedMessage("[{}] failed to index after [{}] attempts due to interruption", + new ParameterizedMessage("[{}] failed to {} after [{}] attempts due to interruption", jobId, + actionName, currentAttempt), interruptedException); Thread.currentThread().interrupt(); } } - String bulkFailureMessage = bulkResponse == null ? "" : bulkResponse.buildFailureMessage(); - LOGGER.warn("[{}] failed to index after [{}] attempts.", jobId, currentAttempt); - throw new ElasticsearchException("[{}] failed to index all results after [{}] attempts. {}", - jobId, - currentAttempt, - bulkFailureMessage); - } - - private BulkResponse bulkIndex(BulkRequest bulkRequest) { - try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) { - return client.bulk(bulkRequest).actionGet(); - } } private BulkRequest buildNewRequestFromFailures(BulkRequest bulkRequest, BulkResponse bulkResponse) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java index 060b1c88463..0bd08bb3603 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.routing.OperationRouting; import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider; @@ -20,6 +21,7 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.index.reindex.ReindexPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.action.util.QueryPage; @@ -115,13 +117,14 @@ public class AutodetectResultProcessorIT extends MlSingleNodeTestCase { ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING))); ClusterService clusterService = new ClusterService(settings, clusterSettings, tp); - resultsPersisterService = new ResultsPersisterService(client(), clusterService, settings); + OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN); + resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, settings); resultProcessor = new AutodetectResultProcessor( client(), auditor, JOB_ID, renormalizer, - new JobResultsPersister(client(), resultsPersisterService, new AnomalyDetectionAuditor(client(), "test_node")), + new JobResultsPersister(originSettingClient, resultsPersisterService, new AnomalyDetectionAuditor(client(), "test_node")), process, new ModelSizeStats.Builder(JOB_ID).build(), new TimingStats(JOB_ID)) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/EstablishedMemUsageIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/EstablishedMemUsageIT.java index dcbb6c2bc3f..6ae60f714be 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/EstablishedMemUsageIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/EstablishedMemUsageIT.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.routing.OperationRouting; import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider; import org.elasticsearch.cluster.service.ClusterApplierService; @@ -13,6 +14,7 @@ import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.Job; @@ -56,11 +58,11 @@ public class EstablishedMemUsageIT extends BaseMlIntegTestCase { ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING))); ClusterService clusterService = new ClusterService(settings, clusterSettings, tp); - ResultsPersisterService resultsPersisterService = new ResultsPersisterService(client(), clusterService, settings); + OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN); + ResultsPersisterService resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, settings); jobResultsProvider = new JobResultsProvider(client(), settings); - jobResultsPersister = new JobResultsPersister(client(), - resultsPersisterService, - new AnomalyDetectionAuditor(client(), "test_node")); + jobResultsPersister = new JobResultsPersister( + originSettingClient, resultsPersisterService, new AnomalyDetectionAuditor(client(), "test_node")); } public void testEstablishedMem_givenNoResults() throws Exception { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobResultsProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobResultsProviderIT.java index e4ef681eca3..c7b7faa8175 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobResultsProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobResultsProviderIT.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.metadata.MappingMetaData; import org.elasticsearch.cluster.routing.OperationRouting; import org.elasticsearch.cluster.routing.UnassignedInfo; @@ -30,6 +31,7 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MlMetaIndex; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.PutJobAction; @@ -108,7 +110,8 @@ public class JobResultsProviderIT extends MlSingleNodeTestCase { ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING))); ClusterService clusterService = new ClusterService(builder.build(), clusterSettings, tp); - resultsPersisterService = new ResultsPersisterService(client(), clusterService, builder.build()); + OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN); + resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, builder.build()); auditor = new AnomalyDetectionAuditor(client(), "test_node"); waitForMlTemplates(); } @@ -623,17 +626,20 @@ public class JobResultsProviderIT extends MlSingleNodeTestCase { } private void indexModelSizeStats(ModelSizeStats modelSizeStats) { - JobResultsPersister persister = new JobResultsPersister(client(), resultsPersisterService, auditor); + JobResultsPersister persister = + new JobResultsPersister(new OriginSettingClient(client(), ClientHelper.ML_ORIGIN), resultsPersisterService, auditor); persister.persistModelSizeStats(modelSizeStats, () -> true); } private void indexModelSnapshot(ModelSnapshot snapshot) { - JobResultsPersister persister = new JobResultsPersister(client(), resultsPersisterService, auditor); + JobResultsPersister persister = + new JobResultsPersister(new OriginSettingClient(client(), ClientHelper.ML_ORIGIN), resultsPersisterService, auditor); persister.persistModelSnapshot(snapshot, WriteRequest.RefreshPolicy.IMMEDIATE, () -> true); } private void indexQuantiles(Quantiles quantiles) { - JobResultsPersister persister = new JobResultsPersister(client(), resultsPersisterService, auditor); + JobResultsPersister persister = + new JobResultsPersister(new OriginSettingClient(client(), ClientHelper.ML_ORIGIN), resultsPersisterService, auditor); persister.persistQuantiles(quantiles, () -> true); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java index 62e5b51ab49..19eb64b7f4e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java @@ -5,13 +5,18 @@ */ package org.elasticsearch.xpack.ml.job.persistence; -import org.elasticsearch.action.ActionFuture; -import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.routing.OperationRouting; import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider; import org.elasticsearch.cluster.service.ClusterApplierService; @@ -19,10 +24,14 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedTimingStats; +import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.Quantiles; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.TimingStats; import org.elasticsearch.xpack.core.ml.job.results.AnomalyRecord; import org.elasticsearch.xpack.core.ml.job.results.Bucket; @@ -32,8 +41,12 @@ import org.elasticsearch.xpack.core.ml.job.results.ModelPlot; import org.elasticsearch.xpack.core.ml.utils.ExponentialAverageCalculationContext; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; +import org.elasticsearch.xpack.ml.test.MockOriginSettingClient; import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; +import org.junit.Before; import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.stubbing.Answer; import java.time.Instant; import java.util.ArrayList; @@ -47,7 +60,10 @@ import java.util.Map; import static org.hamcrest.Matchers.equalTo; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -59,9 +75,21 @@ public class JobResultsPersisterTests extends ESTestCase { private static final String JOB_ID = "foo"; + private Client client; + private OriginSettingClient originSettingClient; + private ArgumentCaptor bulkRequestCaptor; + private JobResultsPersister persister; + + @Before + public void setUpTests() { + bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + client = mock(Client.class); + doAnswer(withResponse(mock(BulkResponse.class))).when(client).execute(eq(BulkAction.INSTANCE), any(), any()); + originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN); + persister = new JobResultsPersister(originSettingClient, buildResultsPersisterService(originSettingClient), makeAuditor()); + } + public void testPersistBucket_OneRecord() { - ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); - Client client = mockClient(captor); Bucket bucket = new Bucket("foo", new Date(), 123456); bucket.setAnomalyScore(99.9); bucket.setEventCount(57); @@ -80,9 +108,11 @@ public class JobResultsPersisterTests extends ESTestCase { AnomalyRecord record = new AnomalyRecord(JOB_ID, new Date(), 600); bucket.setRecords(Collections.singletonList(record)); - JobResultsPersister persister = new JobResultsPersister(client, buildResultsPersisterService(client), makeAuditor()); persister.bulkPersisterBuilder(JOB_ID, () -> true).persistBucket(bucket).executeRequest(); - BulkRequest bulkRequest = captor.getValue(); + + verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any()); + + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertEquals(2, bulkRequest.numberOfActions()); String s = ((IndexRequest)bulkRequest.requests().get(0)).source().utf8ToString(); @@ -103,9 +133,6 @@ public class JobResultsPersisterTests extends ESTestCase { } public void testPersistRecords() { - ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); - Client client = mockClient(captor); - List records = new ArrayList<>(); AnomalyRecord r1 = new AnomalyRecord(JOB_ID, new Date(), 42); records.add(r1); @@ -132,9 +159,11 @@ public class JobResultsPersisterTests extends ESTestCase { typicals.add(998765.3); r1.setTypical(typicals); - JobResultsPersister persister = new JobResultsPersister(client, buildResultsPersisterService(client), makeAuditor()); persister.bulkPersisterBuilder(JOB_ID, () -> true).persistRecords(records).executeRequest(); - BulkRequest bulkRequest = captor.getValue(); + + verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any()); + + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertEquals(1, bulkRequest.numberOfActions()); String s = ((IndexRequest) bulkRequest.requests().get(0)).source().utf8ToString(); @@ -158,9 +187,6 @@ public class JobResultsPersisterTests extends ESTestCase { } public void testPersistInfluencers() { - ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); - Client client = mockClient(captor); - List influencers = new ArrayList<>(); Influencer inf = new Influencer(JOB_ID, "infName1", "infValue1", new Date(), 600); inf.setInfluencerScore(16); @@ -168,9 +194,11 @@ public class JobResultsPersisterTests extends ESTestCase { inf.setProbability(0.4); influencers.add(inf); - JobResultsPersister persister = new JobResultsPersister(client, buildResultsPersisterService(client), makeAuditor()); persister.bulkPersisterBuilder(JOB_ID, () -> true).persistInfluencers(influencers).executeRequest(); - BulkRequest bulkRequest = captor.getValue(); + + verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any()); + + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertEquals(1, bulkRequest.numberOfActions()); String s = ((IndexRequest) bulkRequest.requests().get(0)).source().utf8ToString(); @@ -182,10 +210,6 @@ public class JobResultsPersisterTests extends ESTestCase { } public void testExecuteRequest_ClearsBulkRequest() { - ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); - Client client = mockClient(captor); - JobResultsPersister persister = new JobResultsPersister(client, buildResultsPersisterService(client), makeAuditor()); - List influencers = new ArrayList<>(); Influencer inf = new Influencer(JOB_ID, "infName1", "infValue1", new Date(), 600); inf.setInfluencerScore(16); @@ -199,32 +223,31 @@ public class JobResultsPersisterTests extends ESTestCase { } public void testBulkRequestExecutesWhenReachMaxDocs() { - ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); - Client client = mockClient(captor); - JobResultsPersister persister = new JobResultsPersister(client, buildResultsPersisterService(client), makeAuditor()); - JobResultsPersister.Builder bulkBuilder = persister.bulkPersisterBuilder("foo", () -> true); ModelPlot modelPlot = new ModelPlot("foo", new Date(), 123456, 0); for (int i=0; i<=JobRenormalizedResultsPersister.BULK_LIMIT; i++) { bulkBuilder.persistModelPlot(modelPlot); } - verify(client, times(1)).bulk(any()); - verify(client, times(1)).threadPool(); + InOrder inOrder = inOrder(client); + inOrder.verify(client).settings(); + inOrder.verify(client, times(3)).threadPool(); + inOrder.verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any()); verifyNoMoreInteractions(client); } public void testPersistTimingStats() { - ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); - Client client = mockClient(bulkRequestCaptor); - - JobResultsPersister persister = new JobResultsPersister(client, buildResultsPersisterService(client), makeAuditor()); TimingStats timingStats = new TimingStats( "foo", 7, 1.0, 2.0, 1.23, 7.89, new ExponentialAverageCalculationContext(600.0, Instant.ofEpochMilli(123456789), 60.0)); persister.bulkPersisterBuilder(JOB_ID, () -> true).persistTimingStats(timingStats).executeRequest(); - verify(client, times(1)).bulk(bulkRequestCaptor.capture()); + InOrder inOrder = inOrder(client); + inOrder.verify(client).settings(); + inOrder.verify(client, times(3)).threadPool(); + inOrder.verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any()); + verifyNoMoreInteractions(client); + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); assertThat(bulkRequest.requests().size(), equalTo(1)); IndexRequest indexRequest = (IndexRequest) bulkRequest.requests().get(0); @@ -244,26 +267,24 @@ public class JobResultsPersisterTests extends ESTestCase { calculationContextMap.put("latest_timestamp", 123456789); expectedSourceAsMap.put("exponential_average_calculation_context", calculationContextMap); assertThat(indexRequest.sourceAsMap(), equalTo(expectedSourceAsMap)); - - verify(client, times(1)).threadPool(); - verifyNoMoreInteractions(client); } - @SuppressWarnings({"unchecked", "rawtypes"}) + @SuppressWarnings("unchecked") public void testPersistDatafeedTimingStats() { - Client client = mockClient(ArgumentCaptor.forClass(BulkRequest.class)); - JobResultsPersister persister = new JobResultsPersister(client, buildResultsPersisterService(client), makeAuditor()); DatafeedTimingStats timingStats = new DatafeedTimingStats( "foo", 6, 66, 666.0, new ExponentialAverageCalculationContext(600.0, Instant.ofEpochMilli(123456789), 60.0)); persister.persistDatafeedTimingStats(timingStats, WriteRequest.RefreshPolicy.IMMEDIATE); - ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(1)).bulk(indexRequestCaptor.capture()); + InOrder inOrder = inOrder(client); + inOrder.verify(client).settings(); + inOrder.verify(client, times(3)).threadPool(); + inOrder.verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any()); + verifyNoMoreInteractions(client); // Refresh policy is set on the bulk request, not the individual index requests - assertThat(indexRequestCaptor.getValue().getRefreshPolicy(), equalTo(WriteRequest.RefreshPolicy.IMMEDIATE)); - IndexRequest indexRequest = (IndexRequest)indexRequestCaptor.getValue().requests().get(0); + assertThat(bulkRequestCaptor.getValue().getRefreshPolicy(), equalTo(WriteRequest.RefreshPolicy.IMMEDIATE)); + IndexRequest indexRequest = (IndexRequest) bulkRequestCaptor.getValue().requests().get(0); assertThat(indexRequest.index(), equalTo(".ml-anomalies-.write-foo")); assertThat(indexRequest.id(), equalTo("foo_datafeed_timing_stats")); Map expectedSourceAsMap = new HashMap<>(); @@ -278,37 +299,88 @@ public class JobResultsPersisterTests extends ESTestCase { calculationContextMap.put("latest_timestamp", 123456789); expectedSourceAsMap.put("exponential_average_calculation_context", calculationContextMap); assertThat(indexRequest.sourceAsMap(), equalTo(expectedSourceAsMap)); - verify(client, times(1)).threadPool(); - verifyNoMoreInteractions(client); - } - - private Client mockClient(ArgumentCaptor captor) { - return mockClientWithResponse(captor, new BulkResponse(new BulkItemResponse[0], 0L)); - } - - @SuppressWarnings({"unchecked", "rawtypes"}) - private Client mockClientWithResponse(ArgumentCaptor captor, BulkResponse... responses) { - Client client = mock(Client.class); - ThreadPool threadPool = mock(ThreadPool.class); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - List> futures = new ArrayList<>(responses.length - 1); - ActionFuture future1 = makeFuture(responses[0]); - for (int i = 1; i < responses.length; i++) { - futures.add(makeFuture(responses[i])); - } - when(client.bulk(captor.capture())).thenReturn(future1, futures.toArray(new ActionFuture[0])); - return client; } @SuppressWarnings("unchecked") - private static ActionFuture makeFuture(BulkResponse response) { - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); - return future; + private void testPersistQuantilesSync(SearchHits searchHits, String expectedIndexOrAlias) { + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.status()).thenReturn(RestStatus.OK); + when(searchResponse.getHits()).thenReturn(searchHits); + doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any()); + + Quantiles quantiles = new Quantiles("foo", new Date(), "bar"); + persister.persistQuantiles(quantiles, () -> false); + + InOrder inOrder = inOrder(client); + inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any()); + inOrder.verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any()); + inOrder.verifyNoMoreInteractions(); + + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); + assertThat(bulkRequest.requests().size(), equalTo(1)); + IndexRequest indexRequest = (IndexRequest) bulkRequest.requests().get(0); + + assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias)); + assertThat(indexRequest.id(), equalTo("foo_quantiles")); } - private ResultsPersisterService buildResultsPersisterService(Client client) { + public void testPersistQuantilesSync_QuantilesDocumentCreated() { + testPersistQuantilesSync(SearchHits.empty(), ".ml-state-write"); + } + + public void testPersistQuantilesSync_QuantilesDocumentUpdated() { + testPersistQuantilesSync( + new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f), + ".ml-state-dummy"); + } + + @SuppressWarnings("unchecked") + private void testPersistQuantilesAsync(SearchHits searchHits, String expectedIndexOrAlias) { + ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); + + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getHits()).thenReturn(searchHits); + doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any()); + + IndexResponse indexResponse = mock(IndexResponse.class); + doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any()); + + Quantiles quantiles = new Quantiles("foo", new Date(), "bar"); + ActionListener indexResponseListener = mock(ActionListener.class); + persister.persistQuantiles(quantiles, WriteRequest.RefreshPolicy.IMMEDIATE, indexResponseListener); + + InOrder inOrder = inOrder(client, indexResponseListener); + inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any()); + inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any()); + inOrder.verify(indexResponseListener).onResponse(any()); + inOrder.verifyNoMoreInteractions(); + + IndexRequest indexRequest = indexRequestCaptor.getValue(); + + assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias)); + assertThat(indexRequest.id(), equalTo("foo_quantiles")); + } + + public void testPersistQuantilesAsync_QuantilesDocumentCreated() { + testPersistQuantilesAsync(SearchHits.empty(), ".ml-state-write"); + } + + public void testPersistQuantilesAsync_QuantilesDocumentUpdated() { + testPersistQuantilesAsync( + new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f), + ".ml-state-dummy"); + } + + @SuppressWarnings("unchecked") + private static Answer withResponse(Response response) { + return invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }; + } + + private ResultsPersisterService buildResultsPersisterService(OriginSettingClient client) { ThreadPool tp = mock(ThreadPool.class); ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java index acf4d3d2337..700f83fdc27 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java @@ -6,28 +6,40 @@ package org.elasticsearch.xpack.ml.utils.persistence; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.Client; +import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.cluster.routing.OperationRouting; import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider; import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.MasterService; +import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; +import org.elasticsearch.xpack.ml.test.MockOriginSettingClient; +import org.junit.Before; import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; import java.util.ArrayList; import java.util.Arrays; @@ -35,163 +47,278 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; +import java.util.function.Supplier; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; public class ResultsPersisterServiceTests extends ESTestCase { - private final String JOB_ID = "results_persister_test_job"; - private final Consumer NULL_MSG_HANDLER = (msg) -> {}; + // Common constants + private static final String JOB_ID = "results_persister_test_job"; - public void testBulkRequestChangeOnFailures() { - IndexRequest indexRequestSuccess = new IndexRequest("my-index").id("success").source(Collections.singletonMap("data", "success")); - IndexRequest indexRequestFail = new IndexRequest("my-index").id("fail").source(Collections.singletonMap("data", "fail")); - BulkItemResponse successItem = new BulkItemResponse(1, + // Constants for searchWithRetry tests + private static final SearchRequest SEARCH_REQUEST = new SearchRequest("my-index"); + private static final SearchResponse SEARCH_RESPONSE_SUCCESS = + new SearchResponse(null, null, 1, 1, 0, 0, ShardSearchFailure.EMPTY_ARRAY, null); + private static final SearchResponse SEARCH_RESPONSE_FAILURE = + new SearchResponse(null, null, 1, 0, 0, 0, ShardSearchFailure.EMPTY_ARRAY, null); + + // Constants for bulkIndexWithRetry tests + private static final IndexRequest INDEX_REQUEST_SUCCESS = + new IndexRequest("my-index").id("success").source(Collections.singletonMap("data", "success")); + private static final IndexRequest INDEX_REQUEST_FAILURE = + new IndexRequest("my-index").id("fail").source(Collections.singletonMap("data", "fail")); + private static final BulkItemResponse BULK_ITEM_RESPONSE_SUCCESS = + new BulkItemResponse( + 1, DocWriteRequest.OpType.INDEX, new IndexResponse(new ShardId(AnomalyDetectorsIndex.jobResultsIndexPrefix() + "shared", "uuid", 1), "_doc", - indexRequestSuccess.id(), + INDEX_REQUEST_SUCCESS.id(), 0, 0, 1, true)); - BulkItemResponse failureItem = new BulkItemResponse(2, + private static final BulkItemResponse BULK_ITEM_RESPONSE_FAILURE = + new BulkItemResponse( + 2, DocWriteRequest.OpType.INDEX, new BulkItemResponse.Failure("my-index", "_doc", "fail", new Exception("boom"))); - BulkResponse withFailure = new BulkResponse(new BulkItemResponse[]{ failureItem, successItem }, 0L); - Client client = mockClientWithResponse(withFailure, new BulkResponse(new BulkItemResponse[0], 0L)); + + private Client client; + private OriginSettingClient originSettingClient; + private ResultsPersisterService resultsPersisterService; + + @Before + public void setUpTests() { + client = mock(Client.class); + originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN); + resultsPersisterService = buildResultsPersisterService(originSettingClient); + } + + public void testSearchWithRetries_ImmediateSuccess() { + doAnswer(withResponse(SEARCH_RESPONSE_SUCCESS)) + .when(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + + List messages = new ArrayList<>(); + SearchResponse searchResponse = resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> true, messages::add); + assertThat(searchResponse, is(SEARCH_RESPONSE_SUCCESS)); + assertThat(messages, is(empty())); + + verify(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + } + + public void testSearchWithRetries_SuccessAfterRetry() { + doAnswerWithResponses(SEARCH_RESPONSE_FAILURE, SEARCH_RESPONSE_SUCCESS) + .when(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + + List messages = new ArrayList<>(); + SearchResponse searchResponse = resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> true, messages::add); + assertThat(searchResponse, is(SEARCH_RESPONSE_SUCCESS)); + assertThat(messages, hasSize(1)); + + verify(client, times(2)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + } + + public void testSearchWithRetries_SuccessAfterRetryDueToException() { + doThrow(new IndexNotFoundException("my-index")).doAnswer(withResponse(SEARCH_RESPONSE_SUCCESS)) + .when(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + + List messages = new ArrayList<>(); + SearchResponse searchResponse = resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> true, messages::add); + assertThat(searchResponse, is(SEARCH_RESPONSE_SUCCESS)); + assertThat(messages, hasSize(1)); + + verify(client, times(2)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + } + + private void testSearchWithRetries_FailureAfterTooManyRetries(int maxFailureRetries) { + resultsPersisterService.setMaxFailureRetries(maxFailureRetries); + + doAnswer(withResponse(SEARCH_RESPONSE_FAILURE)) + .when(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + + List messages = new ArrayList<>(); + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> true, messages::add)); + assertThat(e.getMessage(), containsString("failed to search after [" + (maxFailureRetries + 1) + "] attempts.")); + assertThat(messages, hasSize(maxFailureRetries)); + + verify(client, times(maxFailureRetries + 1)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + } + + public void testSearchWithRetries_FailureAfterTooManyRetries_0() { + testSearchWithRetries_FailureAfterTooManyRetries(0); + } + + public void testSearchWithRetries_FailureAfterTooManyRetries_1() { + testSearchWithRetries_FailureAfterTooManyRetries(1); + } + + public void testSearchWithRetries_FailureAfterTooManyRetries_10() { + testSearchWithRetries_FailureAfterTooManyRetries(10); + } + + public void testSearchWithRetries_Failure_ShouldNotRetryFromTheBeginning() { + doAnswer(withResponse(SEARCH_RESPONSE_FAILURE)) + .when(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + + List messages = new ArrayList<>(); + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> false, messages::add)); + assertThat(e.getMessage(), containsString("should not retry search after [1] attempts. SERVICE_UNAVAILABLE")); + assertThat(messages, empty()); + + verify(client, times(1)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + } + + public void testSearchWithRetries_Failure_ShouldNotRetryAfterRandomNumberOfRetries() { + int maxFailureRetries = 10; + resultsPersisterService.setMaxFailureRetries(maxFailureRetries); + + doAnswer(withResponse(SEARCH_RESPONSE_FAILURE)) + .when(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + + int maxRetries = randomIntBetween(1, maxFailureRetries); + List messages = new ArrayList<>(); + ElasticsearchException e = + expectThrows( + ElasticsearchException.class, + () -> resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, shouldRetryUntil(maxRetries), messages::add)); + assertThat( + e.getMessage(), containsString("should not retry search after [" + (maxRetries + 1) + "] attempts. SERVICE_UNAVAILABLE")); + assertThat(messages, hasSize(maxRetries)); + + verify(client, times(maxRetries + 1)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any()); + } + + private static Supplier shouldRetryUntil(int maxRetries) { + return new Supplier() { + int retries = 0; + @Override + public Boolean get() { + return ++retries <= maxRetries; + } + }; + } + + public void testBulkRequestChangeOnFailures() { + doAnswerWithResponses( + new BulkResponse(new BulkItemResponse[]{BULK_ITEM_RESPONSE_FAILURE, BULK_ITEM_RESPONSE_SUCCESS}, 0L), + new BulkResponse(new BulkItemResponse[0], 0L)) + .when(client).execute(eq(BulkAction.INSTANCE), any(), any()); BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(indexRequestFail); - bulkRequest.add(indexRequestSuccess); + bulkRequest.add(INDEX_REQUEST_FAILURE); + bulkRequest.add(INDEX_REQUEST_SUCCESS); - ResultsPersisterService resultsPersisterService = buildResultsPersisterService(client); + AtomicReference lastMessage = new AtomicReference<>(); - resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> true, NULL_MSG_HANDLER); + resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> true, lastMessage::set); ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(2)).bulk(captor.capture()); + verify(client, times(2)).execute(eq(BulkAction.INSTANCE), captor.capture(), any()); List requests = captor.getAllValues(); assertThat(requests.get(0).numberOfActions(), equalTo(2)); assertThat(requests.get(1).numberOfActions(), equalTo(1)); + assertThat(lastMessage.get(), containsString("failed to index after [1] attempts. Will attempt again in")); } public void testBulkRequestDoesNotRetryWhenSupplierIsFalse() { - IndexRequest indexRequestSuccess = new IndexRequest("my-index").id("success").source(Collections.singletonMap("data", "success")); - IndexRequest indexRequestFail = new IndexRequest("my-index").id("fail").source(Collections.singletonMap("data", "fail")); - BulkItemResponse successItem = new BulkItemResponse(1, - DocWriteRequest.OpType.INDEX, - new IndexResponse(new ShardId(AnomalyDetectorsIndex.jobResultsIndexPrefix() + "shared", "uuid", 1), - "_doc", - indexRequestSuccess.id(), - 0, - 0, - 1, - true)); - BulkItemResponse failureItem = new BulkItemResponse(2, - DocWriteRequest.OpType.INDEX, - new BulkItemResponse.Failure("my-index", "_doc", "fail", new Exception("boom"))); - BulkResponse withFailure = new BulkResponse(new BulkItemResponse[]{ failureItem, successItem }, 0L); - Client client = mockClientWithResponse(withFailure, new BulkResponse(new BulkItemResponse[0], 0L)); + doAnswerWithResponses( + new BulkResponse(new BulkItemResponse[]{BULK_ITEM_RESPONSE_FAILURE, BULK_ITEM_RESPONSE_SUCCESS}, 0L), + new BulkResponse(new BulkItemResponse[0], 0L)) + .when(client).execute(eq(BulkAction.INSTANCE), any(), any()); BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(indexRequestFail); - bulkRequest.add(indexRequestSuccess); + bulkRequest.add(INDEX_REQUEST_FAILURE); + bulkRequest.add(INDEX_REQUEST_SUCCESS); - ResultsPersisterService resultsPersisterService = buildResultsPersisterService(client); + AtomicReference lastMessage = new AtomicReference<>(); expectThrows(ElasticsearchException.class, - () -> resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> false, NULL_MSG_HANDLER)); + () -> resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> false, lastMessage::set)); + verify(client, times(1)).execute(eq(BulkAction.INSTANCE), any(), any()); + + assertThat(lastMessage.get(), is(nullValue())); } public void testBulkRequestRetriesConfiguredAttemptNumber() { - IndexRequest indexRequestFail = new IndexRequest("my-index").id("fail").source(Collections.singletonMap("data", "fail")); - BulkItemResponse failureItem = new BulkItemResponse(2, - DocWriteRequest.OpType.INDEX, - new BulkItemResponse.Failure("my-index", "_doc", "fail", new Exception("boom"))); - BulkResponse withFailure = new BulkResponse(new BulkItemResponse[]{ failureItem }, 0L); - Client client = mockClientWithResponse(withFailure); + int maxFailureRetries = 10; + resultsPersisterService.setMaxFailureRetries(maxFailureRetries); + + doAnswer(withResponse(new BulkResponse(new BulkItemResponse[]{BULK_ITEM_RESPONSE_FAILURE}, 0L))) + .when(client).execute(eq(BulkAction.INSTANCE), any(), any()); BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(indexRequestFail); + bulkRequest.add(INDEX_REQUEST_FAILURE); - ResultsPersisterService resultsPersisterService = buildResultsPersisterService(client); + AtomicReference lastMessage = new AtomicReference<>(); - resultsPersisterService.setMaxFailureRetries(1); expectThrows(ElasticsearchException.class, - () -> resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> true, NULL_MSG_HANDLER)); - verify(client, times(2)).bulk(any(BulkRequest.class)); + () -> resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> true, lastMessage::set)); + verify(client, times(maxFailureRetries + 1)).execute(eq(BulkAction.INSTANCE), any(), any()); + + assertThat(lastMessage.get(), containsString("failed to index after [10] attempts. Will attempt again in")); } public void testBulkRequestRetriesMsgHandlerIsCalled() { - IndexRequest indexRequestSuccess = new IndexRequest("my-index").id("success").source(Collections.singletonMap("data", "success")); - IndexRequest indexRequestFail = new IndexRequest("my-index").id("fail").source(Collections.singletonMap("data", "fail")); - BulkItemResponse successItem = new BulkItemResponse(1, - DocWriteRequest.OpType.INDEX, - new IndexResponse(new ShardId(AnomalyDetectorsIndex.jobResultsIndexPrefix() + "shared", "uuid", 1), - "_doc", - indexRequestSuccess.id(), - 0, - 0, - 1, - true)); - BulkItemResponse failureItem = new BulkItemResponse(2, - DocWriteRequest.OpType.INDEX, - new BulkItemResponse.Failure("my-index", "_type", "fail", new Exception("boom"))); - BulkResponse withFailure = new BulkResponse(new BulkItemResponse[]{ failureItem, successItem }, 0L); - Client client = mockClientWithResponse(withFailure, new BulkResponse(new BulkItemResponse[0], 0L)); + doAnswerWithResponses( + new BulkResponse(new BulkItemResponse[]{BULK_ITEM_RESPONSE_FAILURE, BULK_ITEM_RESPONSE_SUCCESS}, 0L), + new BulkResponse(new BulkItemResponse[0], 0L)) + .when(client).execute(eq(BulkAction.INSTANCE), any(), any()); BulkRequest bulkRequest = new BulkRequest(); - bulkRequest.add(indexRequestFail); - bulkRequest.add(indexRequestSuccess); + bulkRequest.add(INDEX_REQUEST_FAILURE); + bulkRequest.add(INDEX_REQUEST_SUCCESS); - ResultsPersisterService resultsPersisterService = buildResultsPersisterService(client); - AtomicReference msgHolder = new AtomicReference<>("not_called"); + AtomicReference lastMessage = new AtomicReference<>(); - resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> true, msgHolder::set); + resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> true, lastMessage::set); ArgumentCaptor captor = ArgumentCaptor.forClass(BulkRequest.class); - verify(client, times(2)).bulk(captor.capture()); + verify(client, times(2)).execute(eq(BulkAction.INSTANCE), captor.capture(), any()); List requests = captor.getAllValues(); assertThat(requests.get(0).numberOfActions(), equalTo(2)); assertThat(requests.get(1).numberOfActions(), equalTo(1)); - assertThat(msgHolder.get(), containsString("failed to index after [1] attempts. Will attempt again in")); + assertThat(lastMessage.get(), containsString("failed to index after [1] attempts. Will attempt again in")); } - @SuppressWarnings({"unchecked", "rawtypes"}) - private Client mockClientWithResponse(BulkResponse... responses) { - Client client = mock(Client.class); - ThreadPool threadPool = mock(ThreadPool.class); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - List> futures = new ArrayList<>(responses.length - 1); - ActionFuture future1 = makeFuture(responses[0]); - for (int i = 1; i < responses.length; i++) { - futures.add(makeFuture(responses[i])); - } - when(client.bulk(any(BulkRequest.class))).thenReturn(future1, futures.toArray(new ActionFuture[0])); - return client; + private static Stubber doAnswerWithResponses(Response response1, Response response2) { + return doAnswer(withResponse(response1)).doAnswer(withResponse(response2)); } @SuppressWarnings("unchecked") - private static ActionFuture makeFuture(BulkResponse response) { - ActionFuture future = mock(ActionFuture.class); - when(future.actionGet()).thenReturn(response); - return future; + private static Answer withResponse(Response response) { + return invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }; } - private ResultsPersisterService buildResultsPersisterService(Client client) { + private static ResultsPersisterService buildResultsPersisterService(OriginSettingClient client) { + CheckedConsumer sleeper = millis -> {}; ThreadPool tp = mock(ThreadPool.class); ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS, @@ -203,6 +330,6 @@ public class ResultsPersisterServiceTests extends ESTestCase { ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING))); ClusterService clusterService = new ClusterService(Settings.EMPTY, clusterSettings, tp); - return new ResultsPersisterService(client, clusterService, Settings.EMPTY); + return new ResultsPersisterService(sleeper, client, clusterService, Settings.EMPTY); } }